From c37c64cac916e91b270fcb707214075643e07ed1 Mon Sep 17 00:00:00 2001 From: hzphzp Date: Mon, 21 Aug 2023 16:38:25 +0800 Subject: [PATCH] upload AFFNet code --- Adaptive Frequency Filters/README.md | 3 +- Adaptive Frequency Filters/affnet/__init__.py | 48 + .../affnet/anchor_generator/__init__.py | 86 + .../anchor_generator/base_anchor_generator.py | 99 + .../anchor_generator/ssd_anchor_generator.py | 199 ++ .../image_projection_layers/__init__.py | 98 + .../attention_pool_2d.py | 190 ++ .../base_image_projection.py | 64 + .../image_projection_layers/global_pool_2d.py | 69 + .../simple_projection_head.py | 62 + .../affnet/layers/__init__.py | 101 + .../affnet/layers/activation/__init__.py | 107 + .../affnet/layers/activation/gelu.py | 25 + .../affnet/layers/activation/hard_sigmoid.py | 32 + .../affnet/layers/activation/hard_swish.py | 33 + .../affnet/layers/activation/leaky_relu.py | 32 + .../affnet/layers/activation/prelu.py | 29 + .../affnet/layers/activation/relu.py | 23 + .../affnet/layers/activation/relu6.py | 23 + .../affnet/layers/activation/sigmoid.py | 23 + .../affnet/layers/activation/swish.py | 23 + .../affnet/layers/activation/tanh.py | 23 + .../affnet/layers/adaptive_pool.py | 32 + .../affnet/layers/base_layer.py | 32 + .../affnet/layers/conv_layer.py | 802 ++++++ .../affnet/layers/conv_layer_complex.py | 281 +++ .../affnet/layers/dropout.py | 57 + .../affnet/layers/embedding.py | 57 + .../affnet/layers/flatten.py | 31 + .../affnet/layers/global_pool.py | 88 + .../affnet/layers/identity.py | 25 + .../affnet/layers/linear_attention.py | 233 ++ .../affnet/layers/linear_layer.py | 253 ++ .../affnet/layers/multi_head_attention.py | 333 +++ .../affnet/layers/non_linear_layers.py | 77 + .../affnet/layers/normalization/__init__.py | 143 ++ .../affnet/layers/normalization/batch_norm.py | 173 ++ .../affnet/layers/normalization/group_norm.py | 50 + .../layers/normalization/instance_norm.py | 95 + .../affnet/layers/normalization/layer_norm.py | 145 ++ .../layers/normalization/sync_batch_norm.py | 88 + .../affnet/layers/normalization_layers.py | 148 ++ .../affnet/layers/pixel_shuffle.py | 34 + .../affnet/layers/pooling.py | 91 + .../affnet/layers/positional_embedding.py | 189 ++ .../affnet/layers/positional_encoding.py | 162 ++ .../affnet/layers/random_layers.py | 61 + .../affnet/layers/single_head_attention.py | 154 ++ .../affnet/layers/softmax.py | 27 + .../affnet/layers/stocastic_depth.py | 23 + .../affnet/layers/upsample.py | 44 + .../affnet/matcher_det/__init__.py | 79 + .../affnet/matcher_det/base_matcher.py | 25 + .../affnet/matcher_det/ssd_matcher.py | 160 ++ .../affnet/misc/__init__.py | 0 .../affnet/misc/averaging_utils.py | 74 + .../affnet/misc/box_utils.py | 119 + .../affnet/misc/common.py | 166 ++ .../affnet/misc/init_utils.py | 151 ++ .../affnet/misc/profiler.py | 33 + .../affnet/misc/third_party/__init__.py | 0 .../affnet/misc/third_party/ssd_utils.py | 126 + .../affnet/models/__init__.py | 109 + .../affnet/models/classification/__init__.py | 146 ++ .../affnet/models/classification/affnet.py | 304 +++ .../affnet/models/classification/base_cls.py | 459 ++++ .../models/classification/config/__init__.py | 0 .../models/classification/config/affnet.py | 329 +++ .../affnet/models/detection/__init__.py | 139 ++ .../affnet/models/detection/base_detection.py | 146 ++ .../affnet/models/detection/mask_rcnn.py | 863 +++++++ .../affnet/models/detection/ssd.py | 686 ++++++ .../affnet/models/detection/utils/__init__.py | 0 .../models/detection/utils/rcnn_utils.py | 264 ++ .../affnet/models/segmentation/__init__.py | 149 ++ .../affnet/models/segmentation/base_seg.py | 184 ++ .../affnet/models/segmentation/enc_dec.py | 198 ++ .../models/segmentation/heads/__init__.py | 72 + .../segmentation/heads/base_seg_head.py | 175 ++ .../models/segmentation/heads/deeplabv3.py | 156 ++ .../models/segmentation/heads/pspnet.py | 147 ++ .../affnet/modules/__init__.py | 27 + .../affnet/modules/aff_block.py | 541 +++++ .../affnet/modules/aspp_block.py | 267 ++ .../affnet/modules/base_module.py | 25 + .../affnet/modules/cbam.py | 211 ++ .../affnet/modules/complexFunctions.py | 133 + .../affnet/modules/complexLayers.py | 466 ++++ .../affnet/modules/efficientnet.py | 52 + .../affnet/modules/feature_pyramid.py | 175 ++ .../affnet/modules/mobilenetv2.py | 257 ++ .../affnet/modules/mobilevit_block.py | 724 ++++++ .../affnet/modules/pspnet_module.py | 135 ++ .../affnet/modules/resnet_modules.py | 265 ++ .../affnet/modules/squeeze_excitation.py | 90 + .../affnet/modules/ssd_heads.py | 263 ++ .../affnet/modules/swin_transformer_block.py | 429 ++++ .../affnet/modules/transformer.py | 299 +++ .../affnet/neural_augmentor/__init__.py | 15 + .../affnet/neural_augmentor/neural_aug.py | 320 +++ .../affnet/neural_augmentor/utils/__init__.py | 0 .../utils/neural_aug_utils.py | 141 ++ .../affnet/text_encoders/__init__.py | 92 + .../affnet/text_encoders/base_text_encoder.py | 111 + .../affnet/text_encoders/transformer.py | 515 ++++ Adaptive Frequency Filters/common/__init__.py | 25 + Adaptive Frequency Filters/data/__init__.py | 7 + .../data/collate_fns/__init__.py | 95 + .../data/collate_fns/collate_functions.py | 43 + .../data/data_loaders.py | 138 ++ .../data/datasets/__init__.py | 292 +++ .../data/datasets/classification/__init__.py | 9 + .../data/datasets/classification/imagenet.py | 221 ++ .../datasets/classification/imagenet_fast.py | 198 ++ .../classification/imagenet_opencv.py | 162 ++ .../imagenet_opencv_bitplane_fast.py | 159 ++ .../classification/imagenet_opencv_fast.py | 161 ++ .../datasets/classification/imagenet_v2.py | 174 ++ .../data/datasets/dataset_base.py | 231 ++ .../data/datasets/detection/__init__.py | 0 .../data/datasets/detection/coco_base.py | 343 +++ .../data/datasets/detection/coco_mask_rcnn.py | 151 ++ .../data/datasets/detection/coco_ssd.py | 225 ++ .../datasets/multi_modal_img_text/__init__.py | 41 + .../base_multi_modal_img_text.py | 419 ++++ .../img_text_tar_dataset.py | 411 ++++ .../zero_shot/__init__.py | 95 + .../zero_shot/base_zero_shot.py | 47 + .../zero_shot/imagenet.py | 1143 +++++++++ .../data/datasets/segmentation/__init__.py | 0 .../data/datasets/segmentation/ade20k.py | 522 ++++ .../segmentation/coco_segmentation.py | 231 ++ .../data/datasets/segmentation/pascal_voc.py | 276 +++ .../datasets/video_classification/__init__.py | 0 .../datasets/video_classification/kinetics.py | 287 +++ .../data/loader/__init__.py | 0 .../data/loader/dataloader.py | 55 + .../data/sampler/__init__.py | 117 + .../data/sampler/base_sampler.py | 296 +++ .../data/sampler/batch_sampler.py | 156 ++ .../data/sampler/multi_scale_sampler.py | 340 +++ .../data/sampler/utils.py | 125 + .../data/sampler/variable_batch_sampler.py | 422 ++++ .../data/sampler/video_batch_sampler.py | 163 ++ .../sampler/video_variable_seq_sampler.py | 318 +++ .../data/text_tokenizer/__init__.py | 92 + .../data/text_tokenizer/base_tokenizer.py | 45 + .../data/text_tokenizer/clip_tokenizer.py | 88 + .../data/transforms/__init__.py | 57 + .../data/transforms/base_transforms.py | 27 + .../data/transforms/image_opencv.py | 1761 ++++++++++++++ .../data/transforms/image_pil.py | 2159 +++++++++++++++++ .../data/transforms/image_torch.py | 248 ++ .../data/transforms/utils.py | 48 + .../data/transforms/video.py | 609 +++++ .../data/video_reader/__init__.py | 115 + .../data/video_reader/base_video_reader.py | 234 ++ .../data/video_reader/default_video_reader.py | 132 + .../data/video_reader/key_frame_reader.py | 116 + Adaptive Frequency Filters/engine/__init__.py | 8 + .../engine/detection_utils/__init__.py | 0 .../engine/detection_utils/coco_map.py | 110 + .../engine/eval_detection.py | 410 ++++ .../engine/eval_segmentation.py | 500 ++++ .../engine/evaluation_engine.py | 204 ++ .../engine/segmentation_utils/__init__.py | 0 .../segmentation_utils/cityscapes_iou.py | 42 + .../engine/training_engine.py | 1004 ++++++++ Adaptive Frequency Filters/engine/utils.py | 172 ++ .../loss_fn/__init__.py | 120 + .../loss_fn/base_criteria.py | 51 + .../loss_fn/base_neural_aug.py | 220 ++ .../loss_fn/classification.py | 48 + .../classification_loss_fns/__init__.py | 82 + .../binary_cross_entropy.py | 36 + .../classification_loss_fns/cross_entropy.py | 75 + .../cross_entropy_with_neural_aug.py | 94 + .../loss_fn/detection.py | 56 + .../loss_fn/detection_loss_fns/__init__.py | 76 + .../detection_loss_fns/mask_rcnn_loss.py | 105 + .../mask_rcnn_loss_with_neural_aug.py | 87 + .../detection_loss_fns/ssd_multibox_loss.py | 190 ++ .../loss_fn/detection_loss_fns/utils.py | 58 + .../loss_fn/distillation.py | 53 + .../loss_fn/distillation_loss_fns/__init__.py | 68 + .../distillation_loss_fns/cls_kl_div_loss.py | 159 ++ .../cls_kl_div_loss_neural_aug.py | 104 + .../loss_fn/distillation_loss_fns/utils.py | 37 + .../loss_fn/multi_modal_img_text.py | 47 + .../multi_modal_img_text_loss_fns/__init__.py | 81 + .../contrastive_loss_clip.py | 102 + .../contrastive_loss_clip_with_neural_aug.py | 67 + .../loss_fn/segmentation.py | 48 + .../loss_fn/segmentation_loss_fns/__init__.py | 78 + .../segmentation_loss_fns/cross_entropy.py | 131 + .../seg_cross_entropy_with_neural_aug.py | 108 + .../loss_landscape/__init__.py | 0 .../loss_landscape/landscape_utils.py | 135 ++ Adaptive Frequency Filters/main_eval.py | 161 ++ Adaptive Frequency Filters/main_train.py | 302 +++ .../metrics/__init__.py | 74 + .../metrics/coco_map.py | 245 ++ .../metrics/confusion_mat.py | 43 + .../metrics/intersection_over_union.py | 50 + .../metrics/metric_monitor.py | 315 +++ Adaptive Frequency Filters/metrics/psnr.py | 29 + Adaptive Frequency Filters/metrics/stats.py | 234 ++ .../metrics/topk_accuracy.py | 30 + Adaptive Frequency Filters/optim/__init__.py | 169 ++ Adaptive Frequency Filters/optim/adam.py | 69 + Adaptive Frequency Filters/optim/adamw.py | 69 + .../optim/base_optim.py | 20 + .../optim/scheduler/__init__.py | 118 + .../optim/scheduler/base_scheduler.py | 58 + .../optim/scheduler/cosine.py | 93 + .../optim/scheduler/cyclic.py | 180 ++ .../optim/scheduler/fixed.py | 69 + .../optim/scheduler/multi_step.py | 95 + .../optim/scheduler/polynomial.py | 94 + Adaptive Frequency Filters/optim/sgd.py | 61 + .../options/__init__.py | 0 Adaptive Frequency Filters/options/opts.py | 522 ++++ .../options/parse_args.py | 43 + Adaptive Frequency Filters/options/utils.py | 119 + Adaptive Frequency Filters/requirements.txt | 56 + .../requirements_docs.txt | 5 + Adaptive Frequency Filters/utils/__init__.py | 0 .../utils/checkpoint_utils.py | 314 +++ Adaptive Frequency Filters/utils/color_map.py | 62 + .../utils/common_utils.py | 128 + Adaptive Frequency Filters/utils/ddp_utils.py | 89 + .../utils/download_utils.py | 15 + .../utils/download_utils_base.py | 89 + Adaptive Frequency Filters/utils/logger.py | 127 + .../utils/math_utils.py | 37 + .../utils/my_dataset_folder.py | 100 + .../utils/pytorch_to_coreml.py | 120 + .../utils/tensor_utils.py | 157 ++ .../utils/third_party/__init__.py | 0 .../utils/third_party/ddp_functional_utils.py | 466 ++++ .../utils/visualization_utils.py | 134 + 241 files changed, 40286 insertions(+), 2 deletions(-) create mode 100644 Adaptive Frequency Filters/affnet/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/anchor_generator/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/anchor_generator/base_anchor_generator.py create mode 100644 Adaptive Frequency Filters/affnet/anchor_generator/ssd_anchor_generator.py create mode 100644 Adaptive Frequency Filters/affnet/image_projection_layers/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/image_projection_layers/attention_pool_2d.py create mode 100644 Adaptive Frequency Filters/affnet/image_projection_layers/base_image_projection.py create mode 100644 Adaptive Frequency Filters/affnet/image_projection_layers/global_pool_2d.py create mode 100644 Adaptive Frequency Filters/affnet/image_projection_layers/simple_projection_head.py create mode 100644 Adaptive Frequency Filters/affnet/layers/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/layers/activation/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/layers/activation/gelu.py create mode 100644 Adaptive Frequency Filters/affnet/layers/activation/hard_sigmoid.py create mode 100644 Adaptive Frequency Filters/affnet/layers/activation/hard_swish.py create mode 100644 Adaptive Frequency Filters/affnet/layers/activation/leaky_relu.py create mode 100644 Adaptive Frequency Filters/affnet/layers/activation/prelu.py create mode 100644 Adaptive Frequency Filters/affnet/layers/activation/relu.py create mode 100644 Adaptive Frequency Filters/affnet/layers/activation/relu6.py create mode 100644 Adaptive Frequency Filters/affnet/layers/activation/sigmoid.py create mode 100644 Adaptive Frequency Filters/affnet/layers/activation/swish.py create mode 100644 Adaptive Frequency Filters/affnet/layers/activation/tanh.py create mode 100644 Adaptive Frequency Filters/affnet/layers/adaptive_pool.py create mode 100644 Adaptive Frequency Filters/affnet/layers/base_layer.py create mode 100644 Adaptive Frequency Filters/affnet/layers/conv_layer.py create mode 100644 Adaptive Frequency Filters/affnet/layers/conv_layer_complex.py create mode 100644 Adaptive Frequency Filters/affnet/layers/dropout.py create mode 100644 Adaptive Frequency Filters/affnet/layers/embedding.py create mode 100644 Adaptive Frequency Filters/affnet/layers/flatten.py create mode 100644 Adaptive Frequency Filters/affnet/layers/global_pool.py create mode 100644 Adaptive Frequency Filters/affnet/layers/identity.py create mode 100644 Adaptive Frequency Filters/affnet/layers/linear_attention.py create mode 100644 Adaptive Frequency Filters/affnet/layers/linear_layer.py create mode 100644 Adaptive Frequency Filters/affnet/layers/multi_head_attention.py create mode 100644 Adaptive Frequency Filters/affnet/layers/non_linear_layers.py create mode 100644 Adaptive Frequency Filters/affnet/layers/normalization/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/layers/normalization/batch_norm.py create mode 100644 Adaptive Frequency Filters/affnet/layers/normalization/group_norm.py create mode 100644 Adaptive Frequency Filters/affnet/layers/normalization/instance_norm.py create mode 100644 Adaptive Frequency Filters/affnet/layers/normalization/layer_norm.py create mode 100644 Adaptive Frequency Filters/affnet/layers/normalization/sync_batch_norm.py create mode 100644 Adaptive Frequency Filters/affnet/layers/normalization_layers.py create mode 100644 Adaptive Frequency Filters/affnet/layers/pixel_shuffle.py create mode 100644 Adaptive Frequency Filters/affnet/layers/pooling.py create mode 100644 Adaptive Frequency Filters/affnet/layers/positional_embedding.py create mode 100644 Adaptive Frequency Filters/affnet/layers/positional_encoding.py create mode 100644 Adaptive Frequency Filters/affnet/layers/random_layers.py create mode 100644 Adaptive Frequency Filters/affnet/layers/single_head_attention.py create mode 100644 Adaptive Frequency Filters/affnet/layers/softmax.py create mode 100644 Adaptive Frequency Filters/affnet/layers/stocastic_depth.py create mode 100644 Adaptive Frequency Filters/affnet/layers/upsample.py create mode 100644 Adaptive Frequency Filters/affnet/matcher_det/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/matcher_det/base_matcher.py create mode 100644 Adaptive Frequency Filters/affnet/matcher_det/ssd_matcher.py create mode 100644 Adaptive Frequency Filters/affnet/misc/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/misc/averaging_utils.py create mode 100644 Adaptive Frequency Filters/affnet/misc/box_utils.py create mode 100644 Adaptive Frequency Filters/affnet/misc/common.py create mode 100644 Adaptive Frequency Filters/affnet/misc/init_utils.py create mode 100644 Adaptive Frequency Filters/affnet/misc/profiler.py create mode 100644 Adaptive Frequency Filters/affnet/misc/third_party/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/misc/third_party/ssd_utils.py create mode 100644 Adaptive Frequency Filters/affnet/models/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/models/classification/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/models/classification/affnet.py create mode 100644 Adaptive Frequency Filters/affnet/models/classification/base_cls.py create mode 100644 Adaptive Frequency Filters/affnet/models/classification/config/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/models/classification/config/affnet.py create mode 100644 Adaptive Frequency Filters/affnet/models/detection/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/models/detection/base_detection.py create mode 100644 Adaptive Frequency Filters/affnet/models/detection/mask_rcnn.py create mode 100644 Adaptive Frequency Filters/affnet/models/detection/ssd.py create mode 100644 Adaptive Frequency Filters/affnet/models/detection/utils/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/models/detection/utils/rcnn_utils.py create mode 100644 Adaptive Frequency Filters/affnet/models/segmentation/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/models/segmentation/base_seg.py create mode 100644 Adaptive Frequency Filters/affnet/models/segmentation/enc_dec.py create mode 100644 Adaptive Frequency Filters/affnet/models/segmentation/heads/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/models/segmentation/heads/base_seg_head.py create mode 100644 Adaptive Frequency Filters/affnet/models/segmentation/heads/deeplabv3.py create mode 100644 Adaptive Frequency Filters/affnet/models/segmentation/heads/pspnet.py create mode 100644 Adaptive Frequency Filters/affnet/modules/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/modules/aff_block.py create mode 100644 Adaptive Frequency Filters/affnet/modules/aspp_block.py create mode 100644 Adaptive Frequency Filters/affnet/modules/base_module.py create mode 100644 Adaptive Frequency Filters/affnet/modules/cbam.py create mode 100644 Adaptive Frequency Filters/affnet/modules/complexFunctions.py create mode 100644 Adaptive Frequency Filters/affnet/modules/complexLayers.py create mode 100644 Adaptive Frequency Filters/affnet/modules/efficientnet.py create mode 100644 Adaptive Frequency Filters/affnet/modules/feature_pyramid.py create mode 100644 Adaptive Frequency Filters/affnet/modules/mobilenetv2.py create mode 100644 Adaptive Frequency Filters/affnet/modules/mobilevit_block.py create mode 100644 Adaptive Frequency Filters/affnet/modules/pspnet_module.py create mode 100644 Adaptive Frequency Filters/affnet/modules/resnet_modules.py create mode 100644 Adaptive Frequency Filters/affnet/modules/squeeze_excitation.py create mode 100644 Adaptive Frequency Filters/affnet/modules/ssd_heads.py create mode 100644 Adaptive Frequency Filters/affnet/modules/swin_transformer_block.py create mode 100644 Adaptive Frequency Filters/affnet/modules/transformer.py create mode 100644 Adaptive Frequency Filters/affnet/neural_augmentor/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/neural_augmentor/neural_aug.py create mode 100644 Adaptive Frequency Filters/affnet/neural_augmentor/utils/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/neural_augmentor/utils/neural_aug_utils.py create mode 100644 Adaptive Frequency Filters/affnet/text_encoders/__init__.py create mode 100644 Adaptive Frequency Filters/affnet/text_encoders/base_text_encoder.py create mode 100644 Adaptive Frequency Filters/affnet/text_encoders/transformer.py create mode 100644 Adaptive Frequency Filters/common/__init__.py create mode 100644 Adaptive Frequency Filters/data/__init__.py create mode 100644 Adaptive Frequency Filters/data/collate_fns/__init__.py create mode 100644 Adaptive Frequency Filters/data/collate_fns/collate_functions.py create mode 100644 Adaptive Frequency Filters/data/data_loaders.py create mode 100644 Adaptive Frequency Filters/data/datasets/__init__.py create mode 100644 Adaptive Frequency Filters/data/datasets/classification/__init__.py create mode 100644 Adaptive Frequency Filters/data/datasets/classification/imagenet.py create mode 100644 Adaptive Frequency Filters/data/datasets/classification/imagenet_fast.py create mode 100644 Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv.py create mode 100644 Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv_bitplane_fast.py create mode 100644 Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv_fast.py create mode 100644 Adaptive Frequency Filters/data/datasets/classification/imagenet_v2.py create mode 100644 Adaptive Frequency Filters/data/datasets/dataset_base.py create mode 100644 Adaptive Frequency Filters/data/datasets/detection/__init__.py create mode 100644 Adaptive Frequency Filters/data/datasets/detection/coco_base.py create mode 100644 Adaptive Frequency Filters/data/datasets/detection/coco_mask_rcnn.py create mode 100644 Adaptive Frequency Filters/data/datasets/detection/coco_ssd.py create mode 100644 Adaptive Frequency Filters/data/datasets/multi_modal_img_text/__init__.py create mode 100644 Adaptive Frequency Filters/data/datasets/multi_modal_img_text/base_multi_modal_img_text.py create mode 100644 Adaptive Frequency Filters/data/datasets/multi_modal_img_text/img_text_tar_dataset.py create mode 100644 Adaptive Frequency Filters/data/datasets/multi_modal_img_text/zero_shot/__init__.py create mode 100644 Adaptive Frequency Filters/data/datasets/multi_modal_img_text/zero_shot/base_zero_shot.py create mode 100644 Adaptive Frequency Filters/data/datasets/multi_modal_img_text/zero_shot/imagenet.py create mode 100644 Adaptive Frequency Filters/data/datasets/segmentation/__init__.py create mode 100644 Adaptive Frequency Filters/data/datasets/segmentation/ade20k.py create mode 100644 Adaptive Frequency Filters/data/datasets/segmentation/coco_segmentation.py create mode 100644 Adaptive Frequency Filters/data/datasets/segmentation/pascal_voc.py create mode 100644 Adaptive Frequency Filters/data/datasets/video_classification/__init__.py create mode 100644 Adaptive Frequency Filters/data/datasets/video_classification/kinetics.py create mode 100644 Adaptive Frequency Filters/data/loader/__init__.py create mode 100644 Adaptive Frequency Filters/data/loader/dataloader.py create mode 100644 Adaptive Frequency Filters/data/sampler/__init__.py create mode 100644 Adaptive Frequency Filters/data/sampler/base_sampler.py create mode 100644 Adaptive Frequency Filters/data/sampler/batch_sampler.py create mode 100644 Adaptive Frequency Filters/data/sampler/multi_scale_sampler.py create mode 100644 Adaptive Frequency Filters/data/sampler/utils.py create mode 100644 Adaptive Frequency Filters/data/sampler/variable_batch_sampler.py create mode 100644 Adaptive Frequency Filters/data/sampler/video_batch_sampler.py create mode 100644 Adaptive Frequency Filters/data/sampler/video_variable_seq_sampler.py create mode 100644 Adaptive Frequency Filters/data/text_tokenizer/__init__.py create mode 100644 Adaptive Frequency Filters/data/text_tokenizer/base_tokenizer.py create mode 100644 Adaptive Frequency Filters/data/text_tokenizer/clip_tokenizer.py create mode 100644 Adaptive Frequency Filters/data/transforms/__init__.py create mode 100644 Adaptive Frequency Filters/data/transforms/base_transforms.py create mode 100644 Adaptive Frequency Filters/data/transforms/image_opencv.py create mode 100644 Adaptive Frequency Filters/data/transforms/image_pil.py create mode 100644 Adaptive Frequency Filters/data/transforms/image_torch.py create mode 100644 Adaptive Frequency Filters/data/transforms/utils.py create mode 100644 Adaptive Frequency Filters/data/transforms/video.py create mode 100644 Adaptive Frequency Filters/data/video_reader/__init__.py create mode 100644 Adaptive Frequency Filters/data/video_reader/base_video_reader.py create mode 100644 Adaptive Frequency Filters/data/video_reader/default_video_reader.py create mode 100644 Adaptive Frequency Filters/data/video_reader/key_frame_reader.py create mode 100644 Adaptive Frequency Filters/engine/__init__.py create mode 100644 Adaptive Frequency Filters/engine/detection_utils/__init__.py create mode 100644 Adaptive Frequency Filters/engine/detection_utils/coco_map.py create mode 100644 Adaptive Frequency Filters/engine/eval_detection.py create mode 100644 Adaptive Frequency Filters/engine/eval_segmentation.py create mode 100644 Adaptive Frequency Filters/engine/evaluation_engine.py create mode 100644 Adaptive Frequency Filters/engine/segmentation_utils/__init__.py create mode 100644 Adaptive Frequency Filters/engine/segmentation_utils/cityscapes_iou.py create mode 100644 Adaptive Frequency Filters/engine/training_engine.py create mode 100644 Adaptive Frequency Filters/engine/utils.py create mode 100644 Adaptive Frequency Filters/loss_fn/__init__.py create mode 100644 Adaptive Frequency Filters/loss_fn/base_criteria.py create mode 100644 Adaptive Frequency Filters/loss_fn/base_neural_aug.py create mode 100644 Adaptive Frequency Filters/loss_fn/classification.py create mode 100644 Adaptive Frequency Filters/loss_fn/classification_loss_fns/__init__.py create mode 100644 Adaptive Frequency Filters/loss_fn/classification_loss_fns/binary_cross_entropy.py create mode 100644 Adaptive Frequency Filters/loss_fn/classification_loss_fns/cross_entropy.py create mode 100644 Adaptive Frequency Filters/loss_fn/classification_loss_fns/cross_entropy_with_neural_aug.py create mode 100644 Adaptive Frequency Filters/loss_fn/detection.py create mode 100644 Adaptive Frequency Filters/loss_fn/detection_loss_fns/__init__.py create mode 100644 Adaptive Frequency Filters/loss_fn/detection_loss_fns/mask_rcnn_loss.py create mode 100644 Adaptive Frequency Filters/loss_fn/detection_loss_fns/mask_rcnn_loss_with_neural_aug.py create mode 100644 Adaptive Frequency Filters/loss_fn/detection_loss_fns/ssd_multibox_loss.py create mode 100644 Adaptive Frequency Filters/loss_fn/detection_loss_fns/utils.py create mode 100644 Adaptive Frequency Filters/loss_fn/distillation.py create mode 100644 Adaptive Frequency Filters/loss_fn/distillation_loss_fns/__init__.py create mode 100644 Adaptive Frequency Filters/loss_fn/distillation_loss_fns/cls_kl_div_loss.py create mode 100644 Adaptive Frequency Filters/loss_fn/distillation_loss_fns/cls_kl_div_loss_neural_aug.py create mode 100644 Adaptive Frequency Filters/loss_fn/distillation_loss_fns/utils.py create mode 100644 Adaptive Frequency Filters/loss_fn/multi_modal_img_text.py create mode 100644 Adaptive Frequency Filters/loss_fn/multi_modal_img_text_loss_fns/__init__.py create mode 100644 Adaptive Frequency Filters/loss_fn/multi_modal_img_text_loss_fns/contrastive_loss_clip.py create mode 100644 Adaptive Frequency Filters/loss_fn/multi_modal_img_text_loss_fns/contrastive_loss_clip_with_neural_aug.py create mode 100644 Adaptive Frequency Filters/loss_fn/segmentation.py create mode 100644 Adaptive Frequency Filters/loss_fn/segmentation_loss_fns/__init__.py create mode 100644 Adaptive Frequency Filters/loss_fn/segmentation_loss_fns/cross_entropy.py create mode 100644 Adaptive Frequency Filters/loss_fn/segmentation_loss_fns/seg_cross_entropy_with_neural_aug.py create mode 100644 Adaptive Frequency Filters/loss_landscape/__init__.py create mode 100644 Adaptive Frequency Filters/loss_landscape/landscape_utils.py create mode 100644 Adaptive Frequency Filters/main_eval.py create mode 100644 Adaptive Frequency Filters/main_train.py create mode 100644 Adaptive Frequency Filters/metrics/__init__.py create mode 100644 Adaptive Frequency Filters/metrics/coco_map.py create mode 100644 Adaptive Frequency Filters/metrics/confusion_mat.py create mode 100644 Adaptive Frequency Filters/metrics/intersection_over_union.py create mode 100644 Adaptive Frequency Filters/metrics/metric_monitor.py create mode 100644 Adaptive Frequency Filters/metrics/psnr.py create mode 100644 Adaptive Frequency Filters/metrics/stats.py create mode 100644 Adaptive Frequency Filters/metrics/topk_accuracy.py create mode 100644 Adaptive Frequency Filters/optim/__init__.py create mode 100644 Adaptive Frequency Filters/optim/adam.py create mode 100644 Adaptive Frequency Filters/optim/adamw.py create mode 100644 Adaptive Frequency Filters/optim/base_optim.py create mode 100644 Adaptive Frequency Filters/optim/scheduler/__init__.py create mode 100644 Adaptive Frequency Filters/optim/scheduler/base_scheduler.py create mode 100644 Adaptive Frequency Filters/optim/scheduler/cosine.py create mode 100644 Adaptive Frequency Filters/optim/scheduler/cyclic.py create mode 100644 Adaptive Frequency Filters/optim/scheduler/fixed.py create mode 100644 Adaptive Frequency Filters/optim/scheduler/multi_step.py create mode 100644 Adaptive Frequency Filters/optim/scheduler/polynomial.py create mode 100644 Adaptive Frequency Filters/optim/sgd.py create mode 100644 Adaptive Frequency Filters/options/__init__.py create mode 100644 Adaptive Frequency Filters/options/opts.py create mode 100644 Adaptive Frequency Filters/options/parse_args.py create mode 100644 Adaptive Frequency Filters/options/utils.py create mode 100644 Adaptive Frequency Filters/requirements.txt create mode 100644 Adaptive Frequency Filters/requirements_docs.txt create mode 100644 Adaptive Frequency Filters/utils/__init__.py create mode 100644 Adaptive Frequency Filters/utils/checkpoint_utils.py create mode 100644 Adaptive Frequency Filters/utils/color_map.py create mode 100644 Adaptive Frequency Filters/utils/common_utils.py create mode 100644 Adaptive Frequency Filters/utils/ddp_utils.py create mode 100644 Adaptive Frequency Filters/utils/download_utils.py create mode 100644 Adaptive Frequency Filters/utils/download_utils_base.py create mode 100644 Adaptive Frequency Filters/utils/logger.py create mode 100644 Adaptive Frequency Filters/utils/math_utils.py create mode 100644 Adaptive Frequency Filters/utils/my_dataset_folder.py create mode 100644 Adaptive Frequency Filters/utils/pytorch_to_coreml.py create mode 100644 Adaptive Frequency Filters/utils/tensor_utils.py create mode 100644 Adaptive Frequency Filters/utils/third_party/__init__.py create mode 100644 Adaptive Frequency Filters/utils/third_party/ddp_functional_utils.py create mode 100644 Adaptive Frequency Filters/utils/visualization_utils.py diff --git a/Adaptive Frequency Filters/README.md b/Adaptive Frequency Filters/README.md index 4e28303..e6d860a 100644 --- a/Adaptive Frequency Filters/README.md +++ b/Adaptive Frequency Filters/README.md @@ -50,11 +50,10 @@ cd TokenMixers/AFFNet/ 3. Install required packages: ```bash -conda create -fyn ml-cvnets python=3.8 +conda create -fyn AFFNet python=3.8 python -m pip install wandb ptflops einops python -m pip install -r requirements.txt python -m pip install psutil torchstat tqdm -python -m pip install --editable . python -m pip install --upgrade fvcore python -m pip install complexPyTorch ``` diff --git a/Adaptive Frequency Filters/affnet/__init__.py b/Adaptive Frequency Filters/affnet/__init__.py new file mode 100644 index 0000000..5c9d261 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/__init__.py @@ -0,0 +1,48 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse + + +from options.utils import extend_selected_args_with_prefix +from affnet.misc.common import parameter_list +from affnet.anchor_generator import arguments_anchor_gen +from affnet.image_projection_layers import arguments_image_projection_head +from affnet.layers import arguments_nn_layers +from affnet.matcher_det import arguments_box_matcher +from affnet.misc.averaging_utils import arguments_ema, EMA +from affnet.misc.profiler import module_profile +from affnet.models import arguments_model, get_model +from affnet.models.detection.base_detection import DetectionPredTuple +from affnet.neural_augmentor import arguments_neural_augmentor +from affnet.text_encoders import arguments_text_encoder + + +def modeling_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + # model arguments + parser = arguments_model(parser) + # neural network layer argumetns + parser = arguments_nn_layers(parser) + # EMA arguments + parser = arguments_ema(parser) + # anchor generator arguments (for object detection) + parser = arguments_anchor_gen(parser) + # box matcher arguments (for object detection) + parser = arguments_box_matcher(parser) + # text encoder arguments (usually for multi-modal tasks) + parser = arguments_text_encoder(parser) + # image projection head arguments (usually for multi-modal tasks) + parser = arguments_image_projection_head(parser) + # neural aug arguments + parser = arguments_neural_augmentor(parser) + + # Add teacher as a prefix to enable distillation tasks + # keep it as the last entry + parser = extend_selected_args_with_prefix( + parser, check_string="--model", add_prefix="--teacher." + ) + + return parser diff --git a/Adaptive Frequency Filters/affnet/anchor_generator/__init__.py b/Adaptive Frequency Filters/affnet/anchor_generator/__init__.py new file mode 100644 index 0000000..1f2ae06 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/anchor_generator/__init__.py @@ -0,0 +1,86 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse +import os +import importlib + +from utils import logger +from utils.ddp_utils import is_master + +from .base_anchor_generator import BaseAnchorGenerator + +# register anchor generator +ANCHOR_GEN_REGISTRY = {} + + +def register_anchor_generator(name): + """Register anchor generators for object detection""" + + def register_class(cls): + if name in ANCHOR_GEN_REGISTRY: + raise ValueError( + "Cannot register duplicate anchor generator ({})".format(name) + ) + + if not issubclass(cls, BaseAnchorGenerator): + raise ValueError( + "Anchor generator ({}: {}) must extend BaseAnchorGenerator".format( + name, cls.__name__ + ) + ) + + ANCHOR_GEN_REGISTRY[name] = cls + return cls + + return register_class + + +def arguments_anchor_gen(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Arguments related to anchor generator for object detection""" + group = parser.add_argument_group("Anchor generator", "Anchor generator") + group.add_argument( + "--anchor-generator.name", type=str, help="Name of the anchor generator" + ) + + for k, v in ANCHOR_GEN_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + + return parser + + +def build_anchor_generator(opts, *args, **kwargs): + """Build anchor generator for object detection""" + anchor_gen_name = getattr(opts, "anchor_generator.name", None) + anchor_gen = None + if anchor_gen_name in ANCHOR_GEN_REGISTRY: + anchor_gen = ANCHOR_GEN_REGISTRY[anchor_gen_name](opts, *args, **kwargs) + else: + supported_anchor_gens = list(ANCHOR_GEN_REGISTRY.keys()) + supp_anchor_gen_str = ( + "Got {} as anchor generator. Supported anchor generators are:".format( + anchor_gen_name + ) + ) + for i, m_name in enumerate(supported_anchor_gens): + supp_anchor_gen_str += "\n\t {}: {}".format(i, logger.color_text(m_name)) + + if is_master(opts): + logger.error(supp_anchor_gen_str) + return anchor_gen + + +# automatically import the anchor generators +anchor_gen_dir = os.path.dirname(__file__) +for file in os.listdir(anchor_gen_dir): + path = os.path.join(anchor_gen_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + anc_gen = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("affnet.anchor_generator." + anc_gen) diff --git a/Adaptive Frequency Filters/affnet/anchor_generator/base_anchor_generator.py b/Adaptive Frequency Filters/affnet/anchor_generator/base_anchor_generator.py new file mode 100644 index 0000000..10678b6 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/anchor_generator/base_anchor_generator.py @@ -0,0 +1,99 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import Tensor +import argparse +from typing import Optional, Tuple, Union + + +class BaseAnchorGenerator(torch.nn.Module): + """ + Base class for anchor generators for the task of object detection. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__() + self.anchors_dict = dict() + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """ + Add anchor generator-specific arguments to the parser + """ + return parser + + def num_anchors_per_os(self): + """Returns anchors per output stride. Child classes must implement this function.""" + raise NotImplementedError + + @torch.no_grad() + def _generate_anchors( + self, + height: int, + width: int, + output_stride: int, + device: Optional[str] = "cpu", + *args, + **kwargs + ) -> Union[Tensor, Tuple[Tensor, ...]]: + raise NotImplementedError + + @torch.no_grad() + def _get_anchors( + self, + fm_height: int, + fm_width: int, + fm_output_stride: int, + device: Optional[str] = "cpu", + *args, + **kwargs + ) -> Union[Tensor, Tuple[Tensor, ...]]: + key = "h_{}_w_{}_os_{}".format(fm_height, fm_width, fm_output_stride) + if key not in self.anchors_dict: + default_anchors_ctr = self._generate_anchors( + height=fm_height, + width=fm_width, + output_stride=fm_output_stride, + device=device, + *args, + **kwargs + ) + self.anchors_dict[key] = default_anchors_ctr + return default_anchors_ctr + else: + return self.anchors_dict[key] + + @torch.no_grad() + def forward( + self, + fm_height: int, + fm_width: int, + fm_output_stride: int, + device: Optional[str] = "cpu", + *args, + **kwargs + ) -> Union[Tensor, Tuple[Tensor, ...]]: + """ + Returns anchors for the feature map + + Args: + fm_height (int): Height of the feature map + fm_width (int): Width of the feature map + fm_output_stride (int): Output stride of the feature map + device (Optional, str): Device (cpu or cuda). Defaults to cpu + + Returns: + Tensor or Tuple of Tensors + """ + return self._get_anchors( + fm_height=fm_height, + fm_width=fm_width, + fm_output_stride=fm_output_stride, + device=device, + *args, + **kwargs + ) diff --git a/Adaptive Frequency Filters/affnet/anchor_generator/ssd_anchor_generator.py b/Adaptive Frequency Filters/affnet/anchor_generator/ssd_anchor_generator.py new file mode 100644 index 0000000..ff37de7 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/anchor_generator/ssd_anchor_generator.py @@ -0,0 +1,199 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from typing import Optional +import numpy as np +import argparse +from itertools import product +from typing import List +from torch import Tensor + +from utils import logger + +from . import register_anchor_generator, BaseAnchorGenerator + + +@register_anchor_generator(name="ssd") +class SSDAnchorGenerator(BaseAnchorGenerator): + """ + This class generates anchors (or priors) ``on-the-fly`` for the + `single shot object detector (SSD) `_. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + output_strides = getattr( + opts, "anchor_generator.ssd.output_strides", [32, 64, 128, 256, -1] + ) + aspect_ratios = getattr( + opts, "anchor_generator.ssd.aspect_ratios", [[2, 3]] * len(output_strides) + ) + + min_ratio = getattr(opts, "anchor_generator.ssd.min_scale_ratio", 0.1) + max_ratio = getattr(opts, "anchor_generator.ssd.max_scale_ratio", 1.05) + no_clipping = getattr(opts, "anchor_generator.ssd.no_clipping", False) + + step = getattr(opts, "anchor_generator.ssd.step", [1]) + if isinstance(step, int): + step = [step] * len(output_strides) + elif isinstance(step, List) and len(step) <= len(output_strides): + step = step + [1] * (len(output_strides) - len(step)) + else: + logger.error( + "--anchor-generator.ssd.step should be either a list of ints with the same length as " + "the output strides OR an integer" + ) + + super().__init__() + aspect_ratios = [list(set(ar)) for ar in aspect_ratios] + output_strides_aspect_ratio = dict() + for k, v in zip(output_strides, aspect_ratios): + output_strides_aspect_ratio[k] = v + self.output_strides_aspect_ratio = output_strides_aspect_ratio + self.output_strides = output_strides + self.anchors_dict = dict() + + self.num_output_strides = len(output_strides) + self.num_aspect_ratios = len(aspect_ratios) + + scales = np.linspace(min_ratio, max_ratio, len(output_strides) + 1) + self.sizes = dict() + for i, s in enumerate(output_strides): + self.sizes[s] = { + "min": scales[i], + "max": (scales[i] * scales[i + 1]) ** 0.5, + "step": step[i], + } + + self.clip = not no_clipping + self.min_scale_ratio = min_ratio + self.max_scale_ratio = max_ratio + self.step = step + + def __repr__(self): + return "{}(min_scale_ratio={}, max_scale_ratio={}, n_output_strides={}, n_aspect_ratios={}, clipping={})".format( + self.__class__.__name__, + self.min_scale_ratio, + self.max_scale_ratio, + self.num_output_strides, + self.num_aspect_ratios, + self.clip, + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """ + Adds SSD anchor generator-specific arguments to the parser + """ + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--anchor-generator.ssd.output-strides", + nargs="+", + type=int, + help="Output strides of the feature maps for which we want to generate anchors", + ) + group.add_argument( + "--anchor-generator.ssd.aspect-ratios", + nargs="+", + type=float, + action="append", + help="Aspect ratios at each output stride", + ) + + # prior box arguments + # SSD sample priors between min and max box sizes. + # for example, if we use feature maps from three spatial levels (or output strides), then we + # sample width and height for anchor boxes as: + # scales = np.linspace(min_box_size, max_box_size, len(output_strides) + 1) + # min_box dimensions for the first feature map is scales[0] * feature_map_dimensions + # while the max_box dimensions will be sqrt(scales[0] * scales[1]) * feature_map dimensions. And so on + group.add_argument( + "--anchor-generator.ssd.min-scale-ratio", + type=float, + help="Min. scale ratio", + ) + group.add_argument( + "--anchor-generator.ssd.max-scale-ratio", + type=float, + help="Max. scale ratio", + ) + group.add_argument( + "--anchor-generator.ssd.no-clipping", + action="store_true", + help="Don't clip the anchors", + ) + group.add_argument( + "--anchor-generator.ssd.step", + type=int, + default=[1], + nargs="+", + help="Step between pixels", + ) + return parser + + def num_anchors_per_os(self) -> List: + """ + Returns anchors per output stride for SSD + """ + return [2 + 2 * len(ar) for os, ar in self.output_strides_aspect_ratio.items()] + + @torch.no_grad() + def _generate_anchors( + self, + height: int, + width: int, + output_stride: int, + device: Optional[str] = "cpu", + *args, + **kwargs + ) -> Tensor: + min_size_h = self.sizes[output_stride]["min"] + min_size_w = self.sizes[output_stride]["min"] + + max_size_h = self.sizes[output_stride]["max"] + max_size_w = self.sizes[output_stride]["max"] + aspect_ratio = self.output_strides_aspect_ratio[output_stride] + + step = max(1, self.sizes[output_stride]["step"]) + + default_anchors_ctr = [] + + start_step = max(0, step // 2) + + # Note that feature maps are in NCHW format + for y, x in product( + range(start_step, height, step), range(start_step, width, step) + ): + + # [x, y, w, h] format + cx = (x + 0.5) / width + cy = (y + 0.5) / height + + # small box size + default_anchors_ctr.append([cx, cy, min_size_w, min_size_h]) + + # big box size + default_anchors_ctr.append([cx, cy, max_size_w, max_size_h]) + + # change h/w ratio of the small sized box based on aspect ratios + for ratio in aspect_ratio: + ratio = ratio**0.5 + default_anchors_ctr.extend( + [ + [cx, cy, min_size_w * ratio, min_size_h / ratio], + [cx, cy, min_size_w / ratio, min_size_h * ratio], + ] + ) + + default_anchors_ctr = torch.tensor( + default_anchors_ctr, dtype=torch.float, device=device + ) + if self.clip: + default_anchors_ctr = torch.clamp(default_anchors_ctr, min=0.0, max=1.0) + + return default_anchors_ctr diff --git a/Adaptive Frequency Filters/affnet/image_projection_layers/__init__.py b/Adaptive Frequency Filters/affnet/image_projection_layers/__init__.py new file mode 100644 index 0000000..33d9f33 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/image_projection_layers/__init__.py @@ -0,0 +1,98 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import importlib +import argparse +from typing import Optional + +from utils import logger + +from .base_image_projection import BaseImageProjectionHead + + +IMAGE_PROJECTION_HEAD_REGISTRY = {} + + +def register_image_projection_head(name): + # register the image projection head class + def register_image_projection_head_class(cls): + if name in IMAGE_PROJECTION_HEAD_REGISTRY: + raise ValueError( + "Cannot register duplicate image projection layer class ({})".format( + name + ) + ) + + if not issubclass(cls, BaseImageProjectionHead): + raise ValueError( + "Image projection layer class ({}: {}) must extend BaseImageProjection".format( + name, cls.__name__ + ) + ) + + IMAGE_PROJECTION_HEAD_REGISTRY[name] = cls + return cls + + return register_image_projection_head_class + + +def arguments_image_projection_head( + parser: argparse.ArgumentParser, +) -> argparse.ArgumentParser: + # add arguments for base image projection layer + parser = BaseImageProjectionHead.add_arguments(parser) + + # add class specific arguments + for k, v in IMAGE_PROJECTION_HEAD_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + return parser + + +def supported_str(layer_name: Optional[str] = None) -> None: + """Helper utility to print supported image projection heads.""" + supp_list = list(IMAGE_PROJECTION_HEAD_REGISTRY.keys()) + if layer_name is None: + supp_str = "Image projection head name can't be None. \n Supported heads are:" + else: + supp_str = "Image projection head ({}) is not yet supported. \n Supported heads are:".format( + layer_name + ) + for t_name in supp_list: + supp_str += "\n\t{}".format(t_name) + logger.error(supp_str + "\n") + + +def build_image_projection_head( + opts, in_dim: int, out_dim: int, *args, **kwargs +) -> BaseImageProjectionHead: + """Helper function to build the text encoder""" + projection_head_name = getattr(opts, "model.image_projection_head.name", None) + if projection_head_name is None: + supported_str(projection_head_name) + + if projection_head_name in list(IMAGE_PROJECTION_HEAD_REGISTRY.keys()): + return IMAGE_PROJECTION_HEAD_REGISTRY[projection_head_name]( + opts, in_dim, out_dim, *args, **kwargs + ) + else: + supported_str(projection_head_name) + + +# automatically import the image projection heads +image_projection_head_dir = os.path.dirname(__file__) + +for file in os.listdir(image_projection_head_dir): + path = os.path.join(image_projection_head_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + proj_head_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module( + "affnet.image_projection_layers." + proj_head_name + ) diff --git a/Adaptive Frequency Filters/affnet/image_projection_layers/attention_pool_2d.py b/Adaptive Frequency Filters/affnet/image_projection_layers/attention_pool_2d.py new file mode 100644 index 0000000..0ef7d0d --- /dev/null +++ b/Adaptive Frequency Filters/affnet/image_projection_layers/attention_pool_2d.py @@ -0,0 +1,190 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from torch.nn import functional as F +import argparse +from typing import Optional + +from affnet.layers import PositionalEmbedding, MultiHeadAttention +from utils import logger + +from . import BaseImageProjectionHead, register_image_projection_head + + +@register_image_projection_head(name="attention_pool_nchw2nc") +class AttentionPool2dHead(BaseImageProjectionHead): + """This class implements attention pooling layer, as + described in `Clip `_, and should be + used for CNN-style models, including MobileViTs""" + + def __init__(self, opts, in_dim: int, out_dim: int, *args, **kwargs) -> None: + super().__init__(opts, *args, **kwargs) + + num_embeddings = getattr( + opts, + "model.image_projection_head.attention_pool_nchw2nc.num_pos_embeddings", + None, + ) + if num_embeddings is None: + logger.error( + "Number of embeddings can't be None in {}. Please specify using " + "--model.image-projection.attention-pool-2d.num-pos-embeddings argument".format( + self.__class__.__name__ + ) + ) + sin_pos_emb = getattr( + opts, + "model.image_projection_head.attention_pool_nchw2nc.use_sinusoidal_pos_embeddings", + False, + ) + num_heads = getattr( + opts, "model.image_projection_head.attention_pool_nchw2nc.num_attn_heads", 8 + ) + + self.use_pytorch_mha = getattr( + opts, + "model.image_projection_head.attention_pool_nchw2nc.use_pytorch_mha", + False, + ) + + self.positional_embedding = PositionalEmbedding( + opts, + num_embeddings=num_embeddings, + embedding_dim=in_dim, + padding_idx=None, + is_learnable=not sin_pos_emb, + sequence_first=self.use_pytorch_mha, + ) + self.multi_head_attn = MultiHeadAttention( + embed_dim=in_dim, num_heads=num_heads, output_dim=out_dim + ) + + self.embed_dim = in_dim + self.projection_dim = out_dim if out_dim is not None else in_dim + self.sin_pos_emb = sin_pos_emb + self.normalize_features = not getattr( + opts, + "model.image_projection_head.attention_pool_nchw2nc.no_feature_normalization", + False, + ) + + self.reset_parameters() + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--model.image-projection-head.attention-pool-nchw2nc.num-pos-embeddings", + type=int, + default=None, + help="Number of positional embeddings", + ) + + group.add_argument( + "--model.image-projection-head.attention-pool-nchw2nc.use-sinusoidal-pos-embeddings", + action="store_true", + help="Use sinusoidal positional embeddings instead of learnable", + ) + + group.add_argument( + "--model.image-projection-head.attention-pool-nchw2nc.num-attn-heads", + type=int, + default=8, + help="Number of attention heads in {}".format(cls.__name__), + ) + + group.add_argument( + "--model.image-projection-head.attention-pool-nchw2nc.no-feature-normalization", + action="store_true", + help="Don't normalize image features", + ) + + group.add_argument( + "--model.image-projection-head.attention-pool-nchw2nc.use-pytorch-mha", + action="store_true", + help="Use Pytorch Multi-head attention", + ) + + return parser + + def reset_parameters(self): + std = self.projection_dim**-0.5 + nn.init.normal_(self.multi_head_attn.qkv_proj.weight, mean=0.0, std=std) + nn.init.normal_(self.multi_head_attn.out_proj.weight, mean=0.0, std=std) + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + + assert ( + x.dim() == 4 + ), "Input should be 4-dimensional (Batch, in_channels, height, width). Got: {}".format( + x.shape + ) + # x is [batch, in_channels, height, width] + # For CNN-style architectures, including MobileViTs + batch_size, in_channels, in_height, in_width = x.shape + + # Flatten the feature map + # [batch, in_channels, height, width] --> [batch, in_channels, height*width] + x = x.reshape(batch_size, in_channels, in_height * in_width) + + if self.use_pytorch_mha: + # we need sequence first. + # [batch, in_channels, height*width] --> [height*width, batch, in_channels] + x = x.permute(2, 0, 1) + + # global pool + # [height*width, batch, in_channels] --> [1, batch, in_channels] + global_token = torch.mean(x, dim=0, keepdim=True) + + num_pixels = x.shape[0] + + # add positional embedding to pixels + pos_emb = self.positional_embedding(num_pixels).to( + device=x.device, dtype=x.dtype + ) + x = x + pos_emb + + # concat the global token with pixel tokens + # [1, batch, in_channels] || [height*width, batch, in_channels] --> [1 + height*width, batch, in_channels] + x = torch.cat([global_token, x], dim=0) + + # do attention + x = self.multi_head_attn(x, use_pytorch_mha=True) + + # extract embeddings corresponding to global token + x = x[0] + else: + # [batch, in_channels, height*width] --> # [batch, height*width, in_channels] + x = x.transpose(1, 2) + + # global pool + # [batch, height*width, in_channels] --> [batch, 1, in_channels] + global_token = torch.mean(x, dim=1, keepdim=True) + + num_pixels = x.shape[1] + # add positional embedding to pixels + pos_emb = self.positional_embedding(num_pixels).to( + device=x.device, dtype=x.dtype + ) + x = x + pos_emb + + # concat the global token with pixel tokens + # [batch, 1, in_channels] || [batch, height*width, in_channels] --> [batch, 1 + height*width, in_channels] + x = torch.cat([global_token, x], dim=1) + + # do attention + x = self.multi_head_attn(x, use_pytorch_mha=False) + + # extract embeddings corresponding to global token + x = x[:, 0] + + if self.normalize_features: + x = F.normalize(x, dim=-1) + return x diff --git a/Adaptive Frequency Filters/affnet/image_projection_layers/base_image_projection.py b/Adaptive Frequency Filters/affnet/image_projection_layers/base_image_projection.py new file mode 100644 index 0000000..ed4bada --- /dev/null +++ b/Adaptive Frequency Filters/affnet/image_projection_layers/base_image_projection.py @@ -0,0 +1,64 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +import argparse +from typing import Optional, Tuple, Dict + +from affnet import parameter_list + + +class BaseImageProjectionHead(nn.Module): + """Base class that projects image representations to the same space as text representations""" + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__() + + self.lr_mult = getattr(opts, "model.image_projection_head.lr_multiplier", 1.0) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + """Add model specific arguments""" + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--model.image-projection-head.name", + type=str, + default=None, + help="Name of the image projection head", + ) + + group.add_argument( + "--model.image-projection-head.lr-multiplier", + type=float, + default=1.0, + help="LR multiplier for image projection head", + ) + + return parser + + def reset_parameters(self) -> None: + """Reset weights of a given layer""" + raise NotImplementedError + + def get_trainable_parameters( + self, + weight_decay: Optional[float] = 0.0, + no_decay_bn_filter_bias: Optional[bool] = False, + *args, + **kwargs + ): + param_list = parameter_list( + named_parameters=self.named_parameters, + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + ) + return param_list, [self.lr_mult] * len(param_list) + + def forward(self, input: Dict, *args, **kwargs) -> Dict: + raise NotImplementedError diff --git a/Adaptive Frequency Filters/affnet/image_projection_layers/global_pool_2d.py b/Adaptive Frequency Filters/affnet/image_projection_layers/global_pool_2d.py new file mode 100644 index 0000000..deb0eb8 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/image_projection_layers/global_pool_2d.py @@ -0,0 +1,69 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse + +import torch +from torch import nn, Tensor +from torch.nn import functional as F + +from affnet.layers import GlobalPool +from . import BaseImageProjectionHead, register_image_projection_head + + +@register_image_projection_head(name="global_pool_nchw2nc") +class GlobalPool2D(BaseImageProjectionHead): + """This class implements global pooling with linear projection""" + + def __init__(self, opts, in_dim: int, out_dim: int, *args, **kwargs) -> None: + super().__init__(opts, *args, **kwargs) + + scale = in_dim**-0.5 + self.pool = GlobalPool(pool_type="mean", keep_dim=False) + self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim))) + self.in_dim = in_dim + self.out_dim = out_dim + + self.feature_normalization = not getattr( + opts, + "model.image_projection_head.global_pool_nchw2nc.no_feature_normalization", + False, + ) + + self.reset_parameters() + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--model.image-projection-head.global-pool-nchw2nc.no-feature-normalization", + action="store_true", + help="Don't normalize image features", + ) + + return parser + + def reset_parameters(self): + pass + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + # x is of shape [batch, in_dim] + assert ( + x.dim() == 4 + ), "Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {}".format( + x.shape + ) + + # [batch, in_dim, in_height, in_width] --> [batch, in_dim] + x = self.pool(x) + # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim] + x = x @ self.proj + if self.feature_normalization: + x = F.normalize(x, dim=-1) + return x diff --git a/Adaptive Frequency Filters/affnet/image_projection_layers/simple_projection_head.py b/Adaptive Frequency Filters/affnet/image_projection_layers/simple_projection_head.py new file mode 100644 index 0000000..6b2b87d --- /dev/null +++ b/Adaptive Frequency Filters/affnet/image_projection_layers/simple_projection_head.py @@ -0,0 +1,62 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from torch.nn import functional as F +import argparse + +from . import BaseImageProjectionHead, register_image_projection_head + + +@register_image_projection_head(name="simple_projection_nc2nc") +class SimpleImageProjectionHead(BaseImageProjectionHead): + """This class implements simple projection head""" + + def __init__(self, opts, in_dim: int, out_dim: int, *args, **kwargs) -> None: + super().__init__(opts, *args, **kwargs) + + scale = in_dim**-0.5 + self.proj = nn.Parameter(scale * torch.randn(size=(in_dim, out_dim))) + self.in_dim = in_dim + self.out_dim = out_dim + + self.feature_normalizaiton = not getattr( + opts, + "model.image_projection_head.simple_projection_nc2nc.no_feature_normalization", + False, + ) + + self.reset_parameters() + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--model.image-projection-head.simple-projection-nc2nc.no-feature-normalization", + action="store_true", + help="Don't normalize image features", + ) + + return parser + + def reset_parameters(self): + pass + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + # x is of shape [batch, in_dim] + assert ( + x.dim() == 2 + ), "Input should be 2-dimensional (Batch x in_dim). Got: {}".format(x.shape) + + # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim] + x = x @ self.proj + if self.feature_normalizaiton: + x = F.normalize(x, dim=-1) + return x diff --git a/Adaptive Frequency Filters/affnet/layers/__init__.py b/Adaptive Frequency Filters/affnet/layers/__init__.py new file mode 100644 index 0000000..4cc56ff --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/__init__.py @@ -0,0 +1,101 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse +import os +import importlib, inspect + +from .base_layer import BaseLayer +from .conv_layer import ( + ConvLayer, + NormActLayer, + TransposeConvLayer, + ConvLayer3d, + SeparableConv, +) +from .linear_layer import LinearLayer, GroupLinear +from .global_pool import GlobalPool +from .identity import Identity +from .non_linear_layers import get_activation_fn +from .normalization_layers import get_normalization_layer, norm_layers_tuple +from .pixel_shuffle import PixelShuffle +from .upsample import UpSample +from .pooling import MaxPool2d, AvgPool2d +from .normalization_layers import AdjustBatchNormMomentum +from .adaptive_pool import AdaptiveAvgPool2d +from .flatten import Flatten +from .multi_head_attention import MultiHeadAttention +from .dropout import Dropout, Dropout2d +from .single_head_attention import SingleHeadAttention +from .softmax import Softmax +from .linear_attention import LinearSelfAttention +from .embedding import Embedding +from .stocastic_depth import StochasticDepth +from .positional_embedding import PositionalEmbedding + +__all__ = [ + "ConvLayer", + "ConvLayer3d", + "SeparableConv", + "NormActLayer", + "TransposeConvLayer", + "LinearLayer", + "GroupLinear", + "GlobalPool", + "Identity", + "PixelShuffle", + "UpSample", + "MaxPool2d", + "AvgPool2d", + "Dropout", + "Dropout2d", + "AdjustBatchNormMomentum", + "Flatten", + "MultiHeadAttention", + "SingleHeadAttention", + "Softmax", + "LinearSelfAttention", + "Embedding", + "PositionalEmbedding", + "norm_layers_tuple", + "StochasticDepth", +] + + +# iterate through all classes and fetch layer specific arguments +def layer_specific_args(parser: argparse.ArgumentParser): + layer_dir = os.path.dirname(__file__) + parsed_layers = [] + for file in os.listdir(layer_dir): + path = os.path.join(layer_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + layer_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("affnet.layers." + layer_name) + for name, cls in inspect.getmembers(module, inspect.isclass): + if issubclass(cls, BaseLayer) and name not in parsed_layers: + parser = cls.add_arguments(parser) + parsed_layers.append(name) + return parser + + +def arguments_nn_layers(parser: argparse.ArgumentParser): + # Retrieve layer specific arguments + parser = layer_specific_args(parser) + + # activation and normalization arguments + from affnet.layers.activation import arguments_activation_fn + + parser = arguments_activation_fn(parser) + + from affnet.layers.normalization import arguments_norm_layers + + parser = arguments_norm_layers(parser) + + return parser diff --git a/Adaptive Frequency Filters/affnet/layers/activation/__init__.py b/Adaptive Frequency Filters/affnet/layers/activation/__init__.py new file mode 100644 index 0000000..3195b49 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/activation/__init__.py @@ -0,0 +1,107 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import importlib +import argparse +from typing import Optional + +import torch.nn + +from utils import logger + +SUPPORTED_ACT_FNS = [] +ACT_FN_REGISTRY = {} + + +def register_act_fn(name): + def register_fn(cls): + if name in SUPPORTED_ACT_FNS: + raise ValueError( + "Cannot register duplicate activation function ({})".format(name) + ) + SUPPORTED_ACT_FNS.append(name) + ACT_FN_REGISTRY[name] = cls + return cls + + return register_fn + + +def arguments_activation_fn(parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Non-linear functions", description="Non-linear functions" + ) + + group.add_argument( + "--model.activation.name", + default="relu", + type=str, + help="Non-linear function name", + ) + group.add_argument( + "--model.activation.inplace", + action="store_true", + help="Use non-linear functions inplace", + ) + group.add_argument( + "--model.activation.neg-slope", + default=0.1, + type=float, + help="Negative slope in leaky relu function", + ) + group.add_argument( + "--model.activation.sparsity_threshold", + default=0.01, + type=float, + help="sparsity_threshold for the mask", + ) + + return parser + + +def build_activation_layer( + act_type: Optional[str] = "relu", + num_parameters: Optional[int] = -1, + inplace: Optional[bool] = True, + negative_slope: Optional[float] = 0.1, + *args, + **kwargs +) -> torch.nn.Module: + """ + Helper function to build the activation function + """ + if act_type is None: + act_type = "none" + act_type = act_type.lower() + act_layer = None + if act_type in ACT_FN_REGISTRY: + act_layer = ACT_FN_REGISTRY[act_type]( + num_parameters=num_parameters, + inplace=inplace, + negative_slope=negative_slope, + *args, + **kwargs + ) + else: + logger.error( + "Supported activation layers are: {}. Supplied argument is: {}".format( + SUPPORTED_ACT_FNS, act_type + ) + ) + return act_layer + + +# automatically import different activation functions +act_dir = os.path.dirname(__file__) +for file in os.listdir(act_dir): + path = os.path.join(act_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("affnet.layers.activation." + model_name) diff --git a/Adaptive Frequency Filters/affnet/layers/activation/gelu.py b/Adaptive Frequency Filters/affnet/layers/activation/gelu.py new file mode 100644 index 0000000..b63c629 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/activation/gelu.py @@ -0,0 +1,25 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Tuple + +from . import register_act_fn + + +@register_act_fn(name="gelu") +class GELU(nn.GELU): + """ + Applies the `Gaussian Error Linear Units `_ function + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__() + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/activation/hard_sigmoid.py b/Adaptive Frequency Filters/affnet/layers/activation/hard_sigmoid.py new file mode 100644 index 0000000..fa51c88 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/activation/hard_sigmoid.py @@ -0,0 +1,32 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from torch.nn import functional as F +from typing import Optional, Tuple + +from . import register_act_fn + + +@register_act_fn(name="hard_sigmoid") +class Hardsigmoid(nn.Hardsigmoid): + """ + Applies the `Hard Sigmoid `_ function + """ + + def __init__(self, inplace: Optional[bool] = False, *args, **kwargs) -> None: + super().__init__(inplace=inplace) + + def forward(self, input: Tensor, *args, **kwargs) -> Tensor: + if hasattr(F, "hardsigmoid"): + return F.hardsigmoid(input, self.inplace) + else: + return F.relu(input + 3) / 6 + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/activation/hard_swish.py b/Adaptive Frequency Filters/affnet/layers/activation/hard_swish.py new file mode 100644 index 0000000..d5cba2d --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/activation/hard_swish.py @@ -0,0 +1,33 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from torch.nn import functional as F +from typing import Tuple, Optional +from . import register_act_fn + + +@register_act_fn(name="hard_swish") +class Hardswish(nn.Hardswish): + """ + Applies the HardSwish function, as described in the paper + `Searching for MobileNetv3 `_ + """ + + def __init__(self, inplace: Optional[bool] = False, *args, **kwargs) -> None: + super().__init__(inplace=inplace) + + def forward(self, input: Tensor, *args, **kwargs) -> Tensor: + if hasattr(F, "hardswish"): + return F.hardswish(input, self.inplace) + else: + x_hard_sig = F.relu(input + 3) / 6 + return input * x_hard_sig + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/activation/leaky_relu.py b/Adaptive Frequency Filters/affnet/layers/activation/leaky_relu.py new file mode 100644 index 0000000..928baa7 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/activation/leaky_relu.py @@ -0,0 +1,32 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Tuple, Optional + +from . import register_act_fn + + +@register_act_fn(name="leaky_relu") +class LeakyReLU(nn.LeakyReLU): + """ + Applies a leaky relu function. See `Rectifier Nonlinearities Improve Neural Network Acoustic Models` + for more details. + """ + + def __init__( + self, + negative_slope: Optional[float] = 1e-2, + inplace: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__(negative_slope=negative_slope, inplace=inplace) + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/activation/prelu.py b/Adaptive Frequency Filters/affnet/layers/activation/prelu.py new file mode 100644 index 0000000..8ff5f99 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/activation/prelu.py @@ -0,0 +1,29 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Optional, Tuple + +from . import register_act_fn + + +@register_act_fn(name="prelu") +class PReLU(nn.PReLU): + """ + Applies the `Parametric Rectified Linear Unit `_ function + """ + + def __init__( + self, + num_parameters: Optional[int] = 1, + init: Optional[float] = 0.25, + *args, + **kwargs + ) -> None: + super().__init__(num_parameters=num_parameters, init=init) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/activation/relu.py b/Adaptive Frequency Filters/affnet/layers/activation/relu.py new file mode 100644 index 0000000..fef781c --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/activation/relu.py @@ -0,0 +1,23 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Optional, Tuple + +from . import register_act_fn + + +@register_act_fn(name="relu") +class ReLU(nn.ReLU): + """ + Applies Rectified Linear Unit function + """ + + def __init__(self, inplace: Optional[bool] = False, *args, **kwargs) -> None: + super().__init__(inplace=inplace) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/activation/relu6.py b/Adaptive Frequency Filters/affnet/layers/activation/relu6.py new file mode 100644 index 0000000..e946340 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/activation/relu6.py @@ -0,0 +1,23 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Optional, Tuple + +from . import register_act_fn + + +@register_act_fn(name="relu6") +class ReLU6(nn.ReLU6): + """ + Applies the ReLU6 function + """ + + def __init__(self, inplace: Optional[bool] = False, *args, **kwargs) -> None: + super().__init__(inplace=inplace) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/activation/sigmoid.py b/Adaptive Frequency Filters/affnet/layers/activation/sigmoid.py new file mode 100644 index 0000000..c102b7f --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/activation/sigmoid.py @@ -0,0 +1,23 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Tuple + +from . import register_act_fn + + +@register_act_fn(name="sigmoid") +class Sigmoid(nn.Sigmoid): + """ + Applies the sigmoid function + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__() + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/activation/swish.py b/Adaptive Frequency Filters/affnet/layers/activation/swish.py new file mode 100644 index 0000000..1e5f353 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/activation/swish.py @@ -0,0 +1,23 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Optional, Tuple + +from . import register_act_fn + + +@register_act_fn(name="swish") +class Swish(nn.SiLU): + """ + Applies the `Swish (also known as SiLU) `_ function. + """ + + def __init__(self, inplace: Optional[bool] = False, *args, **kwargs) -> None: + super().__init__(inplace=inplace) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/activation/tanh.py b/Adaptive Frequency Filters/affnet/layers/activation/tanh.py new file mode 100644 index 0000000..949990f --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/activation/tanh.py @@ -0,0 +1,23 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Tuple + +from . import register_act_fn + + +@register_act_fn(name="tanh") +class Tanh(nn.Tanh): + """ + Applies Tanh function + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__() + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/adaptive_pool.py b/Adaptive Frequency Filters/affnet/layers/adaptive_pool.py new file mode 100644 index 0000000..b82ec08 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/adaptive_pool.py @@ -0,0 +1,32 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Union, Tuple + + +class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d): + """ + Applies a 2D adaptive average pooling over an input tensor. + + Args: + output_size (Optional, int or Tuple[int, int]): The target output size. If a single int :math:`h` is passed, + then a square output of size :math:`hxh` is produced. If a tuple of size :math:`hxw` is passed, then an + output of size `hxw` is produced. Default is 1. + Shape: + - Input: :math:`(N, C, H, W)` where :math:`N` is the batch size, :math:`C` is the number of input channels, + :math:`H` is the input height, and :math:`W` is the input width + - Output: :math:`(N, C, h, h)` or :math:`(N, C, h, w)` + """ + + def __init__( + self, output_size: Union[int, Tuple[int, int]] = 1, *args, **kwargs + ) -> None: + super().__init__(output_size=output_size) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + input = self.forward(input) + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/base_layer.py b/Adaptive Frequency Filters/affnet/layers/base_layer.py new file mode 100644 index 0000000..a2bf4ce --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/base_layer.py @@ -0,0 +1,32 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +import argparse +from typing import Any, Tuple + + +class BaseLayer(nn.Module): + """ + Base class for neural network layers + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__() + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add layer specific arguments""" + return parser + + def forward(self, *args, **kwargs) -> Any: + pass + + def profile_module(self, *args, **kwargs) -> Tuple[Tensor, float, float]: + raise NotImplementedError + + def __repr__(self): + return "{}".format(self.__class__.__name__) diff --git a/Adaptive Frequency Filters/affnet/layers/conv_layer.py b/Adaptive Frequency Filters/affnet/layers/conv_layer.py new file mode 100644 index 0000000..e14c287 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/conv_layer.py @@ -0,0 +1,802 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from typing import Optional, Union, Tuple +import argparse + +from utils import logger + +from .base_layer import BaseLayer +from .normalization_layers import get_normalization_layer +from .non_linear_layers import get_activation_fn + + +class Conv2d(nn.Conv2d): + """ + Applies a 2D convolution over an input + + Args: + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})` + out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})` + kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution. + stride (Union[int, Tuple[int, int]]): Stride for convolution. Defaults to 1 + padding (Union[int, Tuple[int, int]]): Padding for convolution. Defaults to 0 + dilation (Union[int, Tuple[int, int]]): Dilation rate for convolution. Default: 1 + groups (Optional[int]): Number of groups in convolution. Default: 1 + bias (bool): Use bias. Default: ``False`` + padding_mode (Optional[str]): Padding mode. Default: ``zeros`` + + use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True`` + use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization). + Default: ``True`` + act_name (Optional[str]): Use specific activation function. Overrides the one specified in command line args. + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = 1, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + dilation: Optional[Union[int, Tuple[int, int]]] = 1, + groups: Optional[int] = 1, + bias: Optional[bool] = False, + padding_mode: Optional[str] = "zeros", + *args, + **kwargs + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + ) + + +class ConvLayer(BaseLayer): + """ + Applies a 2D convolution over an input + + Args: + opts: command line arguments + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})` + out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})` + kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution. + stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1 + dilation (Union[int, Tuple[int, int]]): Dilation rate for convolution. Default: 1 + padding (Union[int, Tuple[int, int]]): Padding for convolution. When not specified, + padding is automatically computed based on kernel size + and dilation rage. Default is ``None`` + groups (Optional[int]): Number of groups in convolution. Default: ``1`` + bias (Optional[bool]): Use bias. Default: ``False`` + padding_mode (Optional[str]): Padding mode. Default: ``zeros`` + use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True`` + use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization). + Default: ``True`` + act_name (Optional[str]): Use specific activation function. Overrides the one specified in command line args. + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` + + .. note:: + For depth-wise convolution, `groups=C_{in}=C_{out}`. + """ + + def __init__( + self, + opts, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = 1, + dilation: Optional[Union[int, Tuple[int, int]]] = 1, + padding: Optional[Union[int, Tuple[int, int]]] = None, + groups: Optional[int] = 1, + bias: Optional[bool] = False, + padding_mode: Optional[str] = "zeros", + use_norm: Optional[bool] = True, + use_act: Optional[bool] = True, + act_name: Optional[str] = None, + *args, + **kwargs + ) -> None: + super().__init__() + + if use_norm: + norm_type = getattr(opts, "model.normalization.name", "batch_norm") + if norm_type is not None and norm_type.find("batch") > -1: + assert not bias, "Do not use bias when using normalization layers." + elif norm_type is not None and norm_type.find("layer") > -1: + bias = True + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + + if isinstance(stride, int): + stride = (stride, stride) + + if isinstance(dilation, int): + dilation = (dilation, dilation) + + assert isinstance(kernel_size, Tuple) + assert isinstance(stride, Tuple) + assert isinstance(dilation, Tuple) + + if padding is None: + padding = ( + int((kernel_size[0] - 1) / 2) * dilation[0], + int((kernel_size[1] - 1) / 2) * dilation[1], + ) + + if in_channels % groups != 0: + logger.error( + "Input channels are not divisible by groups. {}%{} != 0 ".format( + in_channels, groups + ) + ) + if out_channels % groups != 0: + logger.error( + "Output channels are not divisible by groups. {}%{} != 0 ".format( + out_channels, groups + ) + ) + + block = nn.Sequential() + + conv_layer = Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + ) + + block.add_module(name="conv", module=conv_layer) + + self.norm_name = None + if use_norm: + norm_layer = get_normalization_layer(opts=opts, num_features=out_channels) + block.add_module(name="norm", module=norm_layer) + self.norm_name = norm_layer.__class__.__name__ + + self.act_name = None + act_type = ( + getattr(opts, "model.activation.name", "prelu") + if act_name is None + else act_name + ) + + if act_type is not None and use_act: + neg_slope = getattr(opts, "model.activation.neg_slope", 0.1) + inplace = getattr(opts, "model.activation.inplace", False) + act_layer = get_activation_fn( + act_type=act_type, + inplace=inplace, + negative_slope=neg_slope, + num_parameters=out_channels, + ) + block.add_module(name="act", module=act_layer) + self.act_name = act_layer.__class__.__name__ + + self.block = block + + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.groups = groups + self.kernel_size = conv_layer.kernel_size + self.bias = bias + self.dilation = dilation + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + cls_name = "{} arguments".format(cls.__name__) + group = parser.add_argument_group(title=cls_name, description=cls_name) + group.add_argument( + "--model.layer.conv-init", + type=str, + default="kaiming_normal", + help="Init type for conv layers", + ) + parser.add_argument( + "--model.layer.conv-init-std-dev", + type=float, + default=None, + help="Std deviation for conv layers", + ) + return parser + + def forward(self, x: Tensor) -> Tensor: + return self.block(x) + + def __repr__(self): + repr_str = self.block[0].__repr__() + repr_str = repr_str[:-1] + + if self.norm_name is not None: + repr_str += ", normalization={}".format(self.norm_name) + + if self.act_name is not None: + repr_str += ", activation={}".format(self.act_name) + repr_str += ")" + return repr_str + + def profile_module(self, input: Tensor) -> (Tensor, float, float): + if input.dim() != 4: + logger.error( + "Conv2d requires 4-dimensional input (BxCxHxW). Provided input has shape: {}".format( + input.size() + ) + ) + + b, in_c, in_h, in_w = input.size() + assert in_c == self.in_channels, "{}!={}".format(in_c, self.in_channels) + + stride_h, stride_w = self.stride + groups = self.groups + + out_h = in_h // stride_h + out_w = in_w // stride_w + + k_h, k_w = self.kernel_size + + # compute MACS + macs = (k_h * k_w) * (in_c * self.out_channels) * (out_h * out_w) * 1.0 + macs /= groups + + if self.bias: + macs += self.out_channels * out_h * out_w + + # compute parameters + params = sum([p.numel() for p in self.parameters()]) + + output = torch.zeros( + size=(b, self.out_channels, out_h, out_w), + dtype=input.dtype, + device=input.device, + ) + # print(macs) + return output, params, macs + + +class TransposeConvLayer(BaseLayer): + """ + Applies a 2D Transpose convolution (aka as Deconvolution) over an input + + Args: + opts: command line arguments + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})` + out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})` + kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution. + stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1 + dilation (Union[int, Tuple[int, int]]): Dilation rate for convolution. Default: 1 + groups (Optional[int]): Number of groups in convolution. Default: 1 + bias (Optional[bool]): Use bias. Default: ``False`` + padding_mode (Optional[str]): Padding mode. Default: ``zeros`` + use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True`` + use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization). + Default: ``True`` + padding (Optional[Union[int, Tuple]]): Padding will be done on both sides of each dimension in the input + output_padding (Optional[Union[int, Tuple]]): Additional padding on the output tensor + auto_padding (Optional[bool]): Compute padding automatically. Default: ``True`` + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` + """ + + def __init__( + self, + opts, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple], + stride: Optional[Union[int, Tuple]] = 1, + dilation: Optional[Union[int, Tuple]] = 1, + groups: Optional[int] = 1, + bias: Optional[bool] = False, + padding_mode: Optional[str] = "zeros", + use_norm: Optional[bool] = True, + use_act: Optional[bool] = True, + padding: Optional[Union[int, Tuple]] = (0, 0), + output_padding: Optional[Union[int, Tuple]] = None, + auto_padding: Optional[bool] = True, + *args, + **kwargs + ): + super().__init__() + + if use_norm: + assert not bias, "Do not use bias when using normalization layers." + + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + + if isinstance(stride, int): + stride = (stride, stride) + + if isinstance(dilation, int): + dilation = (dilation, dilation) + + if output_padding is None: + output_padding = (stride[0] - 1, stride[1] - 1) + + assert isinstance(kernel_size, (tuple, list)) + assert isinstance(stride, (tuple, list)) + assert isinstance(dilation, (tuple, list)) + + if auto_padding: + padding = ( + int((kernel_size[0] - 1) / 2) * dilation[0], + int((kernel_size[1] - 1) / 2) * dilation[1], + ) + + if in_channels % groups != 0: + logger.error( + "Input channels are not divisible by groups. {}%{} != 0 ".format( + in_channels, groups + ) + ) + if out_channels % groups != 0: + logger.error( + "Output channels are not divisible by groups. {}%{} != 0 ".format( + out_channels, groups + ) + ) + + block = nn.Sequential() + conv_layer = nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + output_padding=output_padding, + ) + + block.add_module(name="conv", module=conv_layer) + + self.norm_name = None + if use_norm: + norm_layer = get_normalization_layer(opts=opts, num_features=out_channels) + block.add_module(name="norm", module=norm_layer) + self.norm_name = norm_layer.__class__.__name__ + + self.act_name = None + act_type = getattr(opts, "model.activation.name", "relu") + + if act_type is not None and use_act: + neg_slope = getattr(opts, "model.activation.neg_slope", 0.1) + inplace = getattr(opts, "model.activation.inplace", False) + act_layer = get_activation_fn( + act_type=act_type, + inplace=inplace, + negative_slope=neg_slope, + num_parameters=out_channels, + ) + block.add_module(name="act", module=act_layer) + self.act_name = act_layer.__class__.__name__ + + self.block = block + + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.groups = groups + self.kernel_size = conv_layer.kernel_size + self.bias = bias + + def forward(self, x: Tensor) -> Tensor: + return self.block(x) + + def __repr__(self): + repr_str = self.block[0].__repr__() + repr_str = repr_str[:-1] + + if self.norm_name is not None: + repr_str += ", normalization={}".format(self.norm_name) + + if self.act_name is not None: + repr_str += ", activation={}".format(self.act_name) + repr_str += ")" + return repr_str + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + if input.dim() != 4: + logger.error( + "Conv2d requires 4-dimensional input (BxCxHxW). Provided input has shape: {}".format( + input.size() + ) + ) + + b, in_c, in_h, in_w = input.size() + assert in_c == self.in_channels, "{}!={}".format(in_c, self.in_channels) + + stride_h, stride_w = self.stride + groups = self.groups + + out_h = in_h * stride_h + out_w = in_w * stride_w + + k_h, k_w = self.kernel_size + + # compute MACS + macs = (k_h * k_w) * (in_c * self.out_channels) * (out_h * out_w) * 1.0 + macs /= groups + + if self.bias: + macs += self.out_channels * out_h * out_w + + # compute parameters + params = sum([p.numel() for p in self.parameters()]) + + output = torch.zeros( + size=(b, self.out_channels, out_h, out_w), + dtype=input.dtype, + device=input.device, + ) + # print(macs) + return output, params, macs + + +class NormActLayer(BaseLayer): + """ + Applies a normalization layer followed by an activation layer + + Args: + opts: command-line arguments + num_features: :math:`C` from an expected input of size :math:`(N, C, H, W)` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` + """ + + def __init__(self, opts, num_features, *args, **kwargs): + super().__init__() + + block = nn.Sequential() + + self.norm_name = None + norm_layer = get_normalization_layer(opts=opts, num_features=num_features) + block.add_module(name="norm", module=norm_layer) + self.norm_name = norm_layer.__class__.__name__ + + self.act_name = None + act_type = getattr(opts, "model.activation.name", "prelu") + neg_slope = getattr(opts, "model.activation.neg_slope", 0.1) + inplace = getattr(opts, "model.activation.inplace", False) + act_layer = get_activation_fn( + act_type=act_type, + inplace=inplace, + negative_slope=neg_slope, + num_parameters=num_features, + ) + block.add_module(name="act", module=act_layer) + self.act_name = act_layer.__class__.__name__ + + self.block = block + + def forward(self, x: Tensor) -> Tensor: + return self.block(x) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + # compute parameters + params = sum([p.numel() for p in self.parameters()]) + macs = 0.0 + return input, params, macs + + def __repr__(self): + repr_str = "{}(normalization={}, activation={})".format( + self.__class__.__name__, self.norm_type, self.act_type + ) + return repr_str + + +class ConvLayer3d(BaseLayer): + """ + Applies a 3D convolution over an input + + Args: + opts: command line arguments + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` + out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` + kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution. + stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1 + dilation (Union[int, Tuple[int, int]]): Dilation rate for convolution. Default: 1 + groups (Optional[int]): Number of groups in convolution. Default: 1 + bias (Optional[bool]): Use bias. Default: ``False`` + padding_mode (Optional[str]): Padding mode. Default: ``zeros`` + use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True`` + use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization). + Default: ``True`` + + Shape: + - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` + + .. note:: + For depth-wise convolution, `groups=C_{in}=C_{out}`. + """ + + def __init__( + self, + opts, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple], + stride: Optional[Union[int, Tuple]] = 1, + dilation: Optional[Union[int, Tuple]] = 1, + groups: Optional[int] = 1, + bias: Optional[bool] = False, + padding_mode: Optional[str] = "zeros", + use_norm: Optional[bool] = True, + use_act: Optional[bool] = True, + *args, + **kwargs + ) -> None: + super().__init__() + + if use_norm: + assert not bias, "Do not use bias when using normalization layers." + + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + + if isinstance(stride, int): + stride = (stride, stride, stride) + + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + + assert isinstance(kernel_size, (tuple, list)) + assert isinstance(stride, (tuple, list)) + assert isinstance(dilation, (tuple, list)) + + padding = tuple([int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(3)]) + + if in_channels % groups != 0: + logger.error( + "Input channels are not divisible by groups. {}%{} != 0 ".format( + in_channels, groups + ) + ) + if out_channels % groups != 0: + logger.error( + "Output channels are not divisible by groups. {}%{} != 0 ".format( + out_channels, groups + ) + ) + + block = nn.Sequential() + + conv_layer = nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + ) + + block.add_module(name="conv", module=conv_layer) + + self.norm_name = None + norm_type = getattr(opts, "model.normalization.name", "batch_norm") + if use_norm and norm_type is not None: + if norm_type.find("batch") > -1: + norm_type = "batch_norm_3d" + norm_layer = get_normalization_layer( + opts=opts, num_features=out_channels, norm_type=norm_type + ) + block.add_module(name="norm", module=norm_layer) + self.norm_name = norm_layer.__class__.__name__ + + self.act_name = None + act_type = getattr(opts, "model.activation.name", "prelu") + + if act_type is not None and use_act: + neg_slope = getattr(opts, "model.activation.neg_slope", 0.1) + inplace = getattr(opts, "model.activation.inplace", False) + act_layer = get_activation_fn( + act_type=act_type, + inplace=inplace, + negative_slope=neg_slope, + num_parameters=out_channels, + ) + block.add_module(name="act", module=act_layer) + self.act_name = act_layer.__class__.__name__ + + self.block = block + + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.groups = groups + self.kernel_size = conv_layer.kernel_size + self.bias = bias + self.dilation = dilation + + def forward(self, x: Tensor) -> Tensor: + return self.block(x) + + def __repr__(self): + repr_str = self.block[0].__repr__() + repr_str = repr_str[:-1] + + if self.norm_name is not None: + repr_str += ", normalization={}".format(self.norm_name) + + if self.act_name is not None: + repr_str += ", activation={}".format(self.act_name) + repr_str += ")" + return repr_str + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + if input.dim() != 4: + logger.error( + "Conv2d requires 4-dimensional input (BxCxHxW). Provided input has shape: {}".format( + input.size() + ) + ) + + b, in_c, in_d, in_h, in_w = input.size() + assert in_c == self.in_channels, "{}!={}".format(in_c, self.in_channels) + + stride_d, stride_h, stride_w = self.stride + groups = self.groups + + out_h = in_h // stride_h + out_w = in_w // stride_w + out_d = in_d // stride_d + + k_d, k_h, k_w = self.kernel_size + + # compute MACS + macs = ( + (k_d * k_h * k_w) + * (in_c * self.out_channels) + * (out_h * out_w * out_d) + * 1.0 + ) + macs /= groups + + if self.bias: + macs += self.out_channels * out_h * out_w * out_d + + # compute parameters + params = sum([p.numel() for p in self.parameters()]) + + output = torch.zeros( + size=(b, self.out_channels, out_d, out_h, out_w), + dtype=input.dtype, + device=input.device, + ) + return output, params, macs + + +class SeparableConv(BaseLayer): + """ + Applies a `2D depth-wise separable convolution `_ over a 4D input tensor + + Args: + opts: command line arguments + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})` + out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})` + kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution. + stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1 + dilation (Union[int, Tuple[int, int]]): Dilation rate for convolution. Default: 1 + use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True`` + use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization). Default: ``True`` + bias (Optional[bool]): Use bias. Default: ``False`` + padding_mode (Optional[str]): Padding mode. Default: ``zeros`` + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` + + .. note:: + For depth-wise convolution, `groups=C_{in}=C_{out}`. + """ + + def __init__( + self, + opts, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple], + stride: Optional[Union[int, Tuple]] = 1, + dilation: Optional[Union[int, Tuple]] = 1, + use_norm: Optional[bool] = True, + use_act: Optional[bool] = True, + bias: Optional[bool] = False, + padding_mode: Optional[str] = "zeros", + *args, + **kwargs + ) -> None: + super().__init__() + self.dw_conv = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=in_channels, + bias=False, + padding_mode=padding_mode, + use_norm=True, + use_act=False, + ) + self.pw_conv = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + dilation=1, + groups=1, + bias=bias, + padding_mode=padding_mode, + use_norm=use_norm, + use_act=use_act, + ) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.kernel_size = kernel_size + self.dilation = dilation + + def __repr__(self): + repr_str = "{}(in_channels={}, out_channels={}, kernel_size={}, stride={}, dilation={})".format( + self.__class__.__name__, + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.dilation, + ) + return repr_str + + def forward(self, x: Tensor) -> Tensor: + x = self.dw_conv(x) + x = self.pw_conv(x) + return x + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + params, macs = 0.0, 0.0 + input, p, m = self.dw_conv.profile_module(input) + params += p + macs += m + + input, p, m = self.pw_conv.profile_module(input) + params += p + macs += m + + return input, params, macs diff --git a/Adaptive Frequency Filters/affnet/layers/conv_layer_complex.py b/Adaptive Frequency Filters/affnet/layers/conv_layer_complex.py new file mode 100644 index 0000000..71e8470 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/conv_layer_complex.py @@ -0,0 +1,281 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from typing import Optional, Union, Tuple +import argparse + +from utils import logger + +from .base_layer import BaseLayer +from .normalization_layers import get_normalization_layer, get_complex_normalization_layer +from .non_linear_layers import get_activation_fn, get_complex_activation_fn + +from complexPyTorch.complexLayers import ComplexBatchNorm2d, ComplexConv2d, ComplexLinear + + + +class Conv2d(ComplexConv2d): + """ + Applies a 2D convolution over an input + + Args: + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})` + out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})` + kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution. + stride (Union[int, Tuple[int, int]]): Stride for convolution. Defaults to 1 + padding (Union[int, Tuple[int, int]]): Padding for convolution. Defaults to 0 + dilation (Union[int, Tuple[int, int]]): Dilation rate for convolution. Default: 1 + groups (Optional[int]): Number of groups in convolution. Default: 1 + bias (bool): Use bias. Default: ``False`` + padding_mode (Optional[str]): Padding mode. Default: ``zeros`` + + use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True`` + use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization). + Default: ``True`` + act_name (Optional[str]): Use specific activation function. Overrides the one specified in command line args. + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = 1, + padding: Optional[Union[int, Tuple[int, int]]] = 0, + dilation: Optional[Union[int, Tuple[int, int]]] = 1, + groups: Optional[int] = 1, + bias: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + +class ConvLayerComplex(BaseLayer): + """ + Applies a 2D convolution over an input + + Args: + opts: command line arguments + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})` + out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})` + kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution. + stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1 + dilation (Union[int, Tuple[int, int]]): Dilation rate for convolution. Default: 1 + padding (Union[int, Tuple[int, int]]): Padding for convolution. When not specified, + padding is automatically computed based on kernel size + and dilation rage. Default is ``None`` + groups (Optional[int]): Number of groups in convolution. Default: ``1`` + bias (Optional[bool]): Use bias. Default: ``False`` + padding_mode (Optional[str]): Padding mode. Default: ``zeros`` + use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True`` + use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization). + Default: ``True`` + act_name (Optional[str]): Use specific activation function. Overrides the one specified in command line args. + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` + + .. note:: + For depth-wise convolution, `groups=C_{in}=C_{out}`. + """ + + def __init__( + self, + opts, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Optional[Union[int, Tuple[int, int]]] = 1, + dilation: Optional[Union[int, Tuple[int, int]]] = 1, + padding: Optional[Union[int, Tuple[int, int]]] = None, + groups: Optional[int] = 1, + bias: Optional[bool] = False, + padding_mode: Optional[str] = "zeros", + use_norm: Optional[bool] = True, + use_act: Optional[bool] = True, + act_name: Optional[str] = None, + *args, + **kwargs + ) -> None: + super().__init__() + + if use_norm: + norm_type = getattr(opts, "model.normalization.name", "batch_norm") + if norm_type is not None and norm_type.find("batch") > -1: + assert not bias, "Do not use bias when using normalization layers." + elif norm_type is not None and norm_type.find("layer") > -1: + bias = True + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + + if isinstance(stride, int): + stride = (stride, stride) + + if isinstance(dilation, int): + dilation = (dilation, dilation) + + assert isinstance(kernel_size, Tuple) + assert isinstance(stride, Tuple) + assert isinstance(dilation, Tuple) + + if padding is None: + padding = ( + int((kernel_size[0] - 1) / 2) * dilation[0], + int((kernel_size[1] - 1) / 2) * dilation[1], + ) + + if in_channels % groups != 0: + logger.error( + "Input channels are not divisible by groups. {}%{} != 0 ".format( + in_channels, groups + ) + ) + if out_channels % groups != 0: + logger.error( + "Output channels are not divisible by groups. {}%{} != 0 ".format( + out_channels, groups + ) + ) + + block = nn.Sequential() + + conv_layer = Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + ) + + block.add_module(name="conv", module=conv_layer) + + self.norm_name = None + if use_norm: + norm_layer = get_complex_normalization_layer(opts=opts, num_features=out_channels) + block.add_module(name="norm", module=norm_layer) + self.norm_name = norm_layer.bn_r.__class__.__name__ + + self.act_name = None + act_type = ( + getattr(opts, "model.activation.name", "prelu") + if act_name is None + else act_name + ) + + if act_type is not None and use_act: + neg_slope = getattr(opts, "model.activation.neg_slope", 0.1) + inplace = getattr(opts, "model.activation.inplace", False) + act_layer = get_complex_activation_fn( + act_type=act_type, + inplace=inplace, + negative_slope=neg_slope, + num_parameters=out_channels, + ) + block.add_module(name="act", module=act_layer) + self.act_name = act_layer.act_r.__class__.__name__ + + self.block = block + + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.groups = groups + self.kernel_size = conv_layer.conv_r.kernel_size + self.bias = bias + self.dilation = dilation + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + cls_name = "{} arguments".format(cls.__name__) + group = parser.add_argument_group(title=cls_name, description=cls_name) + group.add_argument( + "--model.layer.conv-complex-init", + type=str, + default="kaiming_normal", + help="Init type for conv layers", + ) + parser.add_argument( + "--model.layer.conv-complex-init-std-dev", + type=float, + default=None, + help="Std deviation for conv layers", + ) + return parser + + def forward(self, x: Tensor) -> Tensor: + return self.block(x) + + def __repr__(self): + repr_str = self.block[0].__repr__() + repr_str = repr_str[:-1] + + if self.norm_name is not None: + repr_str += ", normalization={}".format(self.norm_name) + + if self.act_name is not None: + repr_str += ", activation={}".format(self.act_name) + repr_str += ")" + return repr_str + + def profile_module(self, input: Tensor) -> (Tensor, float, float): + if input.dim() != 4: + logger.error( + "Conv2d requires 4-dimensional input (BxCxHxW). Provided input has shape: {}".format( + input.size() + ) + ) + + b, in_c, in_h, in_w = input.size() + assert in_c == self.in_channels, "{}!={}".format(in_c, self.in_channels) + + stride_h, stride_w = self.stride + groups = self.groups + + out_h = in_h // stride_h + out_w = in_w // stride_w + + k_h, k_w = self.kernel_size + + # compute MACS + macs = (k_h * k_w) * (in_c * self.out_channels) * (out_h * out_w) * 1.0 + macs /= groups + + if self.bias: + macs += self.out_channels * out_h * out_w + + # compute parameters + params = sum([p.numel() for p in self.parameters()]) + + output = torch.zeros( + size=(b, self.out_channels, out_h, out_w), + dtype=input.dtype, + device=input.device, + ) + # print(macs) + return output, params, macs + + diff --git a/Adaptive Frequency Filters/affnet/layers/dropout.py b/Adaptive Frequency Filters/affnet/layers/dropout.py new file mode 100644 index 0000000..8421d34 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/dropout.py @@ -0,0 +1,57 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Optional, Tuple + + +class Dropout(nn.Dropout): + """ + This layer, during training, randomly zeroes some of the elements of the input tensor with probability `p` + using samples from a Bernoulli distribution. + + Args: + p: probability of an element to be zeroed. Default: 0.5 + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(N, *)` where :math:`N` is the batch size + - Output: same as the input + + """ + + def __init__( + self, p: Optional[float] = 0.5, inplace: Optional[bool] = False, *args, **kwargs + ) -> None: + super().__init__(p=p, inplace=inplace) + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + return input, 0.0, 0.0 + + +class Dropout2d(nn.Dropout2d): + """ + This layer, during training, randomly zeroes some of the elements of the 4D input tensor with probability `p` + using samples from a Bernoulli distribution. + + Args: + p: probability of an element to be zeroed. Default: 0.5 + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(N, C, H, W)` where :math:`N` is the batch size, :math:`C` is the input channels, + :math:`H` is the input tensor height, and :math:`W` is the input tensor width + - Output: same as the input + + """ + + def __init__(self, p: float = 0.5, inplace: bool = False): + super().__init__(p=p, inplace=inplace) + + def profile_module(self, input: Tensor, *args, **kwargs) -> (Tensor, float, float): + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/embedding.py b/Adaptive Frequency Filters/affnet/layers/embedding.py new file mode 100644 index 0000000..b1d0dc3 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/embedding.py @@ -0,0 +1,57 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from typing import Optional, Union, Tuple +import argparse + +from utils import logger + +from .base_layer import BaseLayer +from .normalization_layers import get_normalization_layer +from .non_linear_layers import get_activation_fn + + +class Embedding(nn.Embedding): + """A lookup table that stores embeddings of a fixed dictionary and size. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; + therefore, the embedding vector at :attr:`padding_idx` is not updated during training, + i.e. it remains as a fixed "pad". For a newly constructed Embedding, + the embedding vector at :attr:`padding_idx` will default to all zeros, + but can be updated to another value to be used as the padding vector. + + Shape: + - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + """ + + def __init__( + self, + opts, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + *args, + **kwargs + ): + super().__init__( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + ) + + def reset_parameters(self) -> None: + nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5) + if self.padding_idx is not None: + nn.init.constant_(self.weight[self.padding_idx], 0) + + def profile_module(self, input: Tensor, *args, **kwargs) -> (Tensor, float, float): + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/flatten.py b/Adaptive Frequency Filters/affnet/layers/flatten.py new file mode 100644 index 0000000..8ddc780 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/flatten.py @@ -0,0 +1,31 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Tuple, Optional + + +class Flatten(nn.Flatten): + """ + This layer flattens a contiguous range of dimensions into a tensor. + + Args: + start_dim (Optional[int]): first dim to flatten. Default: 1 + end_dim (Optional[int]): last dim to flatten. Default: -1 + + Shape: + - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,' + where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any + number of dimensions including none. + - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`. + """ + + def __init__(self, start_dim: Optional[int] = 1, end_dim: Optional[int] = -1): + super(Flatten, self).__init__(start_dim=start_dim, end_dim=end_dim) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + input = self.forward(input) + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/global_pool.py b/Adaptive Frequency Filters/affnet/layers/global_pool.py new file mode 100644 index 0000000..d69452b --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/global_pool.py @@ -0,0 +1,88 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import Tensor +import argparse +from typing import List, Optional, Tuple + +from utils import logger + +from .base_layer import BaseLayer + + +class GlobalPool(BaseLayer): + """ + This layers applies global pooling over a 4D or 5D input tensor + + Args: + pool_type (Optional[str]): Pooling type. It can be mean, rms, or abs. Default: `mean` + keep_dim (Optional[bool]): Do not squeeze the dimensions of a tensor. Default: `False` + + Shape: + - Input: :math:`(N, C, H, W)` or :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, 1, 1)` or :math:`(N, C, 1, 1, 1)` if keep_dim else :math:`(N, C)` + """ + + pool_types = ["mean", "rms", "abs"] + + def __init__( + self, + pool_type: Optional[str] = "mean", + keep_dim: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__() + if pool_type not in self.pool_types: + logger.error( + "Supported pool types are: {}. Got {}".format( + self.pool_types, pool_type + ) + ) + self.pool_type = pool_type + self.keep_dim = keep_dim + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + cls_name = "{} arguments".format(cls.__name__) + group = parser.add_argument_group(title=cls_name, description=cls_name) + group.add_argument( + "--model.layer.global-pool", + type=str, + default="mean", + help="Which global pooling?", + ) + return parser + + def _global_pool(self, x: Tensor, dims: List): + if self.pool_type == "rms": # root mean square + x = x**2 + x = torch.mean(x, dim=dims, keepdim=self.keep_dim) + x = x**-0.5 + elif self.pool_type == "abs": # absolute + x = torch.mean(torch.abs(x), dim=dims, keepdim=self.keep_dim) + else: + # default is mean + # same as AdaptiveAvgPool + x = torch.mean(x, dim=dims, keepdim=self.keep_dim) + return x + + def forward(self, x: Tensor) -> Tensor: + if x.dim() == 4: + dims = [-2, -1] + elif x.dim() == 5: + dims = [-3, -2, -1] + else: + raise NotImplementedError("Currently 2D and 3D global pooling supported") + return self._global_pool(x, dims=dims) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + input = self.forward(input) + return input, 0.0, 0.0 + + def __repr__(self): + return "{}(type={})".format(self.__class__.__name__, self.pool_type) diff --git a/Adaptive Frequency Filters/affnet/layers/identity.py b/Adaptive Frequency Filters/affnet/layers/identity.py new file mode 100644 index 0000000..3b60196 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/identity.py @@ -0,0 +1,25 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import Tensor +from typing import Tuple + +from .base_layer import BaseLayer + + +class Identity(BaseLayer): + """ + This is a place-holder and returns the same tensor. + """ + + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x: Tensor) -> Tensor: + return x + + def profile_module(self, x: Tensor) -> Tuple[Tensor, float, float]: + return x, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/linear_attention.py b/Adaptive Frequency Filters/affnet/layers/linear_attention.py new file mode 100644 index 0000000..e43249d --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/linear_attention.py @@ -0,0 +1,233 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- +import torch +from torch import Tensor +from typing import Optional, Tuple +from torch.nn import functional as F + +from .base_layer import BaseLayer +from .conv_layer import ConvLayer +from .dropout import Dropout +from ..misc.profiler import module_profile + + +class LinearSelfAttention(BaseLayer): + """ + This layer applies a self-attention with linear complexity, as described in `MobileViTv2 `_ paper. + This layer can be used for self- as well as cross-attention. + + Args: + opts: command line arguments + embed_dim (int): :math:`C` from an expected input of size :math:`(N, C, H, W)` + attn_dropout (Optional[float]): Dropout value for context scores. Default: 0.0 + bias (Optional[bool]): Use bias in learnable layers. Default: True + + Shape: + - Input: :math:`(N, C, P, N)` where :math:`N` is the batch size, :math:`C` is the input channels, + :math:`P` is the number of pixels in the patch, and :math:`N` is the number of patches + - Output: same as the input + + .. note:: + For MobileViTv2, we unfold the feature map [B, C, H, W] into [B, C, P, N] where P is the number of pixels + in a patch and N is the number of patches. Because channel is the first dimension in this unfolded tensor, + we use point-wise convolution (instead of a linear layer). This avoids a transpose operation (which may be + expensive on resource-constrained devices) that may be required to convert the unfolded tensor from + channel-first to channel-last format in case of a linear layer. + """ + + def __init__( + self, + opts, + embed_dim: int, + attn_dropout: Optional[float] = 0.0, + bias: Optional[bool] = True, + *args, + **kwargs + ) -> None: + super().__init__() + + self.qkv_proj = ConvLayer( + opts=opts, + in_channels=embed_dim, + out_channels=1 + (2 * embed_dim), + bias=bias, + kernel_size=1, + use_norm=False, + use_act=False, + ) + + self.attn_dropout = Dropout(p=attn_dropout) + self.out_proj = ConvLayer( + opts=opts, + in_channels=embed_dim, + out_channels=embed_dim, + bias=bias, + kernel_size=1, + use_norm=False, + use_act=False, + ) + self.embed_dim = embed_dim + + def __repr__(self): + return "{}(embed_dim={}, attn_dropout={})".format( + self.__class__.__name__, self.embed_dim, self.attn_dropout.p + ) + + @staticmethod + def visualize_context_scores(context_scores): + # [B, 1, P, N] + batch_size, channels, num_pixels, num_patches = context_scores.shape + + assert batch_size == 1, "For visualization purposes, use batch size of 1" + assert ( + channels == 1 + ), "The inner-product between input and latent node (query) is a scalar" + + up_scale_factor = int(num_pixels**0.5) + patch_h = patch_w = int(context_scores.shape[-1] ** 0.5) + # [1, 1, P, N] --> [1, P, h, w] + context_scores = context_scores.reshape(1, num_pixels, patch_h, patch_w) + # Fold context scores [1, P, h, w] using pixel shuffle to obtain [1, 1, H, W] + context_map = F.pixel_shuffle(context_scores, upscale_factor=up_scale_factor) + # [1, 1, H, W] --> [H, W] + context_map = context_map.squeeze() + + # For ease of visualization, we do min-max normalization + min_val = torch.min(context_map) + max_val = torch.max(context_map) + context_map = (context_map - min_val) / (max_val - min_val) + + try: + import cv2 + from glob import glob + import os + + # convert from float to byte + context_map = (context_map * 255).byte().cpu().numpy() + context_map = cv2.resize( + context_map, (80, 80), interpolation=cv2.INTER_NEAREST + ) + + colored_context_map = cv2.applyColorMap(context_map, cv2.COLORMAP_JET) + # Lazy way to dump feature maps in attn_res folder. Make sure that directory is empty and copy + # context maps before running on different image. Otherwise, attention maps will be overridden. + res_dir_name = "attn_res" + if not os.path.isdir(res_dir_name): + os.makedirs(res_dir_name) + f_name = "{}/h_{}_w_{}_index_".format(res_dir_name, patch_h, patch_w) + + files_cmap = glob( + "{}/h_{}_w_{}_index_*.png".format(res_dir_name, patch_h, patch_w) + ) + idx = len(files_cmap) + f_name += str(idx) + + cv2.imwrite("{}.png".format(f_name), colored_context_map) + return colored_context_map + except ModuleNotFoundError as mnfe: + print("Please install OpenCV to visualize context maps") + return context_map + + def _forward_self_attn(self, x: Tensor, *args, **kwargs) -> Tensor: + # [B, C, P, N] --> [B, h + 2d, P, N] + qkv = self.qkv_proj(x) + + # Project x into query, key and value + # Query --> [B, 1, P, N] + # value, key --> [B, d, P, N] + query, key, value = torch.split( + qkv, split_size_or_sections=[1, self.embed_dim, self.embed_dim], dim=1 + ) + + # apply softmax along N dimension + context_scores = F.softmax(query, dim=-1) + # Uncomment below line to visualize context scores + # self.visualize_context_scores(context_scores=context_scores) + context_scores = self.attn_dropout(context_scores) + + # Compute context vector + # [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N] + context_vector = key * context_scores + # [B, d, P, N] --> [B, d, P, 1] + context_vector = torch.sum(context_vector, dim=-1, keepdim=True) + + # combine context vector with values + # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N] + out = F.relu(value) * context_vector.expand_as(value) + out = self.out_proj(out) + return out + + def _forward_cross_attn( + self, x: Tensor, x_prev: Optional[Tensor] = None, *args, **kwargs + ) -> Tensor: + # x --> [B, C, P, N] + # x_prev = [B, C, P, M] + + batch_size, in_dim, kv_patch_area, kv_num_patches = x.shape + + q_patch_area, q_num_patches = x.shape[-2:] + + assert ( + kv_patch_area == q_patch_area + ), "The number of pixels in a patch for query and key_value should be the same" + + # compute query, key, and value + # [B, C, P, M] --> [B, 1 + d, P, M] + qk = F.conv2d( + x_prev, + weight=self.qkv_proj.block.conv.weight[: self.embed_dim + 1, ...], + bias=self.qkv_proj.block.conv.bias[: self.embed_dim + 1, ...], + ) + # [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M] + query, key = torch.split(qk, split_size_or_sections=[1, self.embed_dim], dim=1) + # [B, C, P, N] --> [B, d, P, N] + value = F.conv2d( + x, + weight=self.qkv_proj.block.conv.weight[self.embed_dim + 1 :, ...], + bias=self.qkv_proj.block.conv.bias[self.embed_dim + 1 :, ...], + ) + + # apply softmax along M dimension + context_scores = F.softmax(query, dim=-1) + context_scores = self.attn_dropout(context_scores) + + # compute context vector + # [B, d, P, M] * [B, 1, P, M] -> [B, d, P, M] + context_vector = key * context_scores + # [B, d, P, M] --> [B, d, P, 1] + context_vector = torch.sum(context_vector, dim=-1, keepdim=True) + + # combine context vector with values + # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N] + out = F.relu(value) * context_vector.expand_as(value) + out = self.out_proj(out) + return out + + def forward( + self, x: Tensor, x_prev: Optional[Tensor] = None, *args, **kwargs + ) -> Tensor: + if x_prev is None: + return self._forward_self_attn(x, *args, **kwargs) + else: + return self._forward_cross_attn(x, x_prev=x_prev, *args, **kwargs) + + def profile_module(self, input) -> Tuple[Tensor, float, float]: + params = macs = 0.0 + + qkv, p, m = module_profile(module=self.qkv_proj, x=input) + params += p + macs += m + + query, key, value = torch.split( + qkv, split_size_or_sections=[1, self.embed_dim, self.embed_dim], dim=1 + ) + + if self.out_proj is not None: + out_p, p, m = module_profile(module=self.out_proj, x=value) + params += p + macs += m + + return input, params, macs diff --git a/Adaptive Frequency Filters/affnet/layers/linear_layer.py b/Adaptive Frequency Filters/affnet/layers/linear_layer.py new file mode 100644 index 0000000..6eef8e7 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/linear_layer.py @@ -0,0 +1,253 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from typing import Optional, Tuple +import argparse +from torch.nn import functional as F + +from utils import logger + +from .base_layer import BaseLayer + + +class LinearLayer(BaseLayer): + """ + Applies a linear transformation to the input data + + Args: + in_features (int): number of features in the input tensor + out_features (int): number of features in the output tensor + bias (Optional[bool]): use bias or not + channel_first (Optional[bool]): Channels are first or last dimension. If first, then use Conv2d + + Shape: + - Input: :math:`(N, *, C_{in})` if not channel_first else :math:`(N, C_{in}, *)` where :math:`*` means any number of dimensions. + - Output: :math:`(N, *, C_{out})` if not channel_first else :math:`(N, C_{out}, *)` + + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: Optional[bool] = True, + channel_first: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) + self.bias = nn.Parameter(torch.Tensor(out_features)) if bias else None + + self.in_features = in_features + self.out_features = out_features + self.channel_first = channel_first + + self.reset_params() + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + parser.add_argument( + "--model.layer.linear-init", + type=str, + default="xavier_uniform", + help="Init type for linear layers", + ) + parser.add_argument( + "--model.layer.linear-init-std-dev", + type=float, + default=0.01, + help="Std deviation for Linear layers", + ) + return parser + + def reset_params(self): + if self.weight is not None: + torch.nn.init.xavier_uniform_(self.weight) + if self.bias is not None: + torch.nn.init.constant_(self.bias, 0) + + def forward(self, x: Tensor) -> Tensor: + if self.channel_first: + if not self.training: + logger.error("Channel-first mode is only supported during inference") + if x.dim() != 4: + logger.error("Input should be 4D, i.e., (B, C, H, W) format") + # only run during conversion + with torch.no_grad(): + return F.conv2d( + input=x, + weight=self.weight.clone() + .detach() + .reshape(self.out_features, self.in_features, 1, 1), + bias=self.bias, + ) + else: + x = F.linear(x, weight=self.weight, bias=self.bias) + return x + + def __repr__(self): + repr_str = ( + "{}(in_features={}, out_features={}, bias={}, channel_first={})".format( + self.__class__.__name__, + self.in_features, + self.out_features, + True if self.bias is not None else False, + self.channel_first, + ) + ) + return repr_str + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + out_size = list(input.shape) + out_size[-1] = self.out_features + params = sum([p.numel() for p in self.parameters()]) + macs = params + output = torch.zeros(size=out_size, dtype=input.dtype, device=input.device) + return output, params, macs + + +class GroupLinear(BaseLayer): + """ + Applies a GroupLinear transformation layer, as defined `here `_, + `here `_ and `here `_ + + Args: + in_features (int): number of features in the input tensor + out_features (int): number of features in the output tensor + n_groups (int): number of groups + bias (Optional[bool]): use bias or not + feature_shuffle (Optional[bool]): Shuffle features between groups + + Shape: + - Input: :math:`(N, *, C_{in})` + - Output: :math:`(N, *, C_{out})` + + """ + + def __init__( + self, + in_features: int, + out_features: int, + n_groups: int, + bias: Optional[bool] = True, + feature_shuffle: Optional[bool] = False, + *args, + **kwargs + ) -> None: + if in_features % n_groups != 0: + logger.error( + "Input dimensions ({}) must be divisible by n_groups ({})".format( + in_features, n_groups + ) + ) + if out_features % n_groups != 0: + logger.error( + "Output dimensions ({}) must be divisible by n_groups ({})".format( + out_features, n_groups + ) + ) + + in_groups = in_features // n_groups + out_groups = out_features // n_groups + + super().__init__() + + self.weight = nn.Parameter(torch.Tensor(n_groups, in_groups, out_groups)) + if bias: + self.bias = nn.Parameter(torch.Tensor(n_groups, 1, out_groups)) + else: + self.bias = None + + self.out_features = out_features + self.in_features = in_features + self.n_groups = n_groups + self.feature_shuffle = feature_shuffle + + self.reset_params() + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + parser.add_argument( + "--model.layer.group-linear-init", + type=str, + default="xavier_uniform", + help="Init type for group linear layers", + ) + parser.add_argument( + "--model.layer.group-linear-init-std-dev", + type=float, + default=0.01, + help="Std deviation for group linear layers", + ) + return parser + + def reset_params(self): + if self.weight is not None: + torch.nn.init.xavier_uniform_(self.weight.data) + if self.bias is not None: + torch.nn.init.constant_(self.bias.data, 0) + + def _forward(self, x: Tensor) -> Tensor: + bsz = x.shape[0] + # [B, N] --> [B, g, N/g] + x = x.reshape(bsz, self.n_groups, -1) + + # [B, g, N/g] --> [g, B, N/g] + x = x.transpose(0, 1) + # [g, B, N/g] x [g, N/g, M/g] --> [g, B, M/g] + x = torch.bmm(x, self.weight) + + if self.bias is not None: + x = torch.add(x, self.bias) + + if self.feature_shuffle: + # [g, B, M/g] --> [B, M/g, g] + x = x.permute(1, 2, 0) + # [B, M/g, g] --> [B, g, M/g] + x = x.reshape(bsz, self.n_groups, -1) + else: + # [g, B, M/g] --> [B, g, M/g] + x = x.transpose(0, 1) + + return x.reshape(bsz, -1) + + def forward(self, x: Tensor) -> Tensor: + if x.dim() == 2: + x = self._forward(x) + return x + else: + in_dims = x.shape[:-1] + n_elements = x.numel() // self.in_features + x = x.reshape(n_elements, -1) + x = self._forward(x) + x = x.reshape(*in_dims, -1) + return x + + def __repr__(self): + repr_str = "{}(in_features={}, out_features={}, groups={}, bias={}, shuffle={})".format( + self.__class__.__name__, + self.in_features, + self.out_features, + self.n_groups, + True if self.bias is not None else False, + self.feature_shuffle, + ) + return repr_str + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + params = sum([p.numel() for p in self.parameters()]) + macs = params + + out_size = list(input.shape) + out_size[-1] = self.out_features + + output = torch.zeros(size=out_size, dtype=input.dtype, device=input.device) + return output, params, macs diff --git a/Adaptive Frequency Filters/affnet/layers/multi_head_attention.py b/Adaptive Frequency Filters/affnet/layers/multi_head_attention.py new file mode 100644 index 0000000..c7470c0 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/multi_head_attention.py @@ -0,0 +1,333 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from typing import Optional, Tuple +from torch.nn import functional as F + +from utils import logger + +from .base_layer import BaseLayer +from .linear_layer import LinearLayer +from .dropout import Dropout +from ..misc.profiler import module_profile + + +class MultiHeadAttention(BaseLayer): + """ + This layer applies a multi-head self- or cross-attention as described in + `Attention is all you need `_ paper + + Args: + embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, S, C_{in})` + num_heads (int): Number of heads in multi-head attention + attn_dropout (Optional[float]): Attention dropout. Default: 0.0 + bias (Optional[bool]): Use bias or not. Default: ``True`` + + Shape: + - Input: + - Query tensor (x_q) :math:`(N, S, C_{in})` where :math:`N` is batch size, :math:`S` is number of source tokens, + and :math:`C_{in}` is input embedding dim + - Optional Key-Value tensor (x_kv) :math:`(N, T, C_{in})` where :math:`T` is number of target tokens + - Output: same shape as the input + + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + attn_dropout: Optional[float] = 0.0, + bias: Optional[bool] = True, + output_dim: Optional[int] = None, + coreml_compatible: Optional[bool] = False, + *args, + **kwargs + ) -> None: + if output_dim is None: + output_dim = embed_dim + super().__init__() + if embed_dim % num_heads != 0: + logger.error( + "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format( + self.__class__.__name__, embed_dim, num_heads + ) + ) + + self.qkv_proj = LinearLayer( + in_features=embed_dim, out_features=3 * embed_dim, bias=bias + ) + + self.attn_dropout = Dropout(p=attn_dropout) + self.out_proj = LinearLayer( + in_features=embed_dim, out_features=output_dim, bias=bias + ) + + self.head_dim = embed_dim // num_heads + self.scaling = self.head_dim**-0.5 + self.softmax = nn.Softmax(dim=-1) + self.num_heads = num_heads + self.embed_dim = embed_dim + self.coreml_compatible = coreml_compatible + self.use_separate_proj_weight = embed_dim != output_dim + + def __repr__(self): + return "{}(head_dim={}, num_heads={}, attn_dropout={})".format( + self.__class__.__name__, self.head_dim, self.num_heads, self.attn_dropout.p + ) + + def forward_tracing( + self, + x_q: Tensor, + x_kv: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + if x_kv is None: + # [N, S, C] --> # [N, S, 3C] Here, T=S + qkv = self.qkv_proj(x_q) + # # [N, S, 3C] --> # [N, S, C] x 3 + query, key, value = torch.chunk(qkv, chunks=3, dim=-1) + else: + # [N, S, C] + query = F.linear( + x_q, + weight=self.qkv_proj.weight[: self.embed_dim, ...], + bias=self.qkv_proj.bias[: self.embed_dim] + if self.qkv_proj.bias is not None + else None, + ) + + # [N, T, C] --> [N, T, 2C] + kv = F.linear( + x_kv, + weight=self.qkv_proj.weight[self.embed_dim :, ...], + bias=self.qkv_proj.bias[self.embed_dim :] + if self.qkv_proj.bias is not None + else None, + ) + key, value = torch.chunk(kv, chunks=2, dim=-1) + + query = query * self.scaling + + # [N, S, C] --> [N, S, c] x h, where C = c * h + query = torch.chunk(query, chunks=self.num_heads, dim=-1) + + # [N, T, C] --> [N, T, c] x h, where C = c * h + value = torch.chunk(value, chunks=self.num_heads, dim=-1) + # [N, T, C] --> [N, T, c] x h, where C = c * h + key = torch.chunk(key, chunks=self.num_heads, dim=-1) + + wt_out = [] + for h in range(self.num_heads): + attn_h = torch.matmul(query[h], key[h].transpose(-1, -2)) + attn_h = self.softmax(attn_h) + attn_h = self.attn_dropout(attn_h) + out_h = torch.matmul(attn_h, value[h]) + wt_out.append(out_h) + + wt_out = torch.cat(wt_out, dim=-1) + wt_out = self.out_proj(wt_out) + return wt_out + + def forward_default( + self, + x_q: Tensor, + x_kv: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + # [N, S, C] + b_sz, S_len, in_channels = x_q.shape + + if x_kv is None: + # self-attention + # [N, S, C] --> [N, S, 3C] --> [N, S, 3, h, c] where C = hc + qkv = self.qkv_proj(x_q).reshape(b_sz, S_len, 3, self.num_heads, -1) + # [N, S, 3, h, c] --> [N, h, 3, S, C] + qkv = qkv.transpose(1, 3).contiguous() + + # [N, h, 3, S, C] --> [N, h, S, C] x 3 + query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + else: + T_len = x_kv.shape[1] + + # cross-attention + # [N, S, C] + query = F.linear( + x_q, + weight=self.qkv_proj.weight[: self.embed_dim, ...], + bias=self.qkv_proj.bias[: self.embed_dim] + if self.qkv_proj.bias is not None + else None, + ) + # [N, S, C] --> [N, S, h, c] --> [N, h, S, c] + query = ( + query.reshape(b_sz, S_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + # [N, T, C] --> [N, T, 2C] + kv = F.linear( + x_kv, + weight=self.qkv_proj.weight[self.embed_dim :, ...], + bias=self.qkv_proj.bias[self.embed_dim :] + if self.qkv_proj.bias is not None + else None, + ) + # [N, T, 2C] --> [N, T, 2, h, c] + kv = kv.reshape(b_sz, T_len, 2, self.num_heads, self.head_dim) + # [N, T, 2, h, c] --> [N, h, 2, T, c] + kv = kv.transpose(1, 3).contiguous() + key, value = kv[:, :, 0], kv[:, :, 1] + + query = query * self.scaling + + # [N h, T, c] --> [N, h, c, T] + key = key.transpose(-1, -2) + + # QK^T + # [N, h, S, c] x [N, h, c, T] --> [N, h, S, T] + attn = torch.matmul(query, key) + + batch_size, num_heads, num_src_tokens, num_tgt_tokens = attn.shape + if attn_mask is not None: + # attn_mask shape should be the same as attn + assert list(attn_mask.shape) == [ + batch_size, + num_src_tokens, + num_tgt_tokens, + ], "Shape of attention mask should be [{}, {}, {}]. Got: {}".format( + batch_size, num_src_tokens, num_tgt_tokens, attn_mask.shape + ) + # [N, S, T] --> [N, 1, S, T] + attn_mask = attn_mask.unsqueeze(1) + attn = attn + attn_mask + + if key_padding_mask is not None: + # Do not attend to padding positions + # key padding mask size is [N, T] + assert key_padding_mask.dim() == 2 and list(key_padding_mask.shape) == [ + batch_size, + num_tgt_tokens, + ], "Key_padding_mask should be 2-dimension with shape [{}, {}]. Got: {}".format( + batch_size, num_tgt_tokens, key_padding_mask.shape + ) + attn = attn.masked_fill( + key_padding_mask.unsqueeze(1) + .unsqueeze(2) + .to(torch.bool), # [N, T] --> [N, 1, 1, T] + float("-inf"), + ) + + attn_dtype = attn.dtype + attn_as_float = self.softmax(attn.float()) + attn = attn_as_float.to(attn_dtype) + attn = self.attn_dropout(attn) + + # weighted sum + # [N, h, S, T] x [N, h, T, c] --> [N, h, S, c] + out = torch.matmul(attn, value) + + # [N, h, S, c] --> [N, S, h, c] --> [N, S, C] + out = out.transpose(1, 2).reshape(b_sz, S_len, -1) + out = self.out_proj(out) + + return out + + def forward_pytorch( + self, + x_q: Tensor, + x_kv: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + out, _ = F.multi_head_attention_forward( + query=x_q, + key=x_kv if x_kv is not None else x_q, + value=x_kv if x_kv is not None else x_q, + embed_dim_to_check=self.embed_dim, + num_heads=self.num_heads, + in_proj_weight=torch.empty([0]), + in_proj_bias=self.qkv_proj.bias, + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=self.attn_dropout.p, + out_proj_weight=self.out_proj.weight, + out_proj_bias=self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=False, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.qkv_proj.weight[: self.embed_dim, ...], + k_proj_weight=self.qkv_proj.weight[ + self.embed_dim : 2 * self.embed_dim, ... + ], + v_proj_weight=self.qkv_proj.weight[2 * self.embed_dim :, ...], + ) + return out + + def forward( + self, + x_q: Tensor, + x_kv: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + *args, + **kwargs + ) -> Tensor: + if self.coreml_compatible: + # For CoreML, we follow batch-first format. Make sure the input is of the form + # [Batch , Sequence, Hidden_dim] + return self.forward_tracing( + x_q=x_q, + x_kv=x_kv, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + ) + elif kwargs.get("use_pytorch_mha", False): + # pytorch uses sequence-first format. Make sure that input is of the form [Sequence, Batch, Hidden dim] + return self.forward_pytorch( + x_q=x_q, + x_kv=x_kv, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + ) + else: + # our default implementation format follows batch-first format. Make sure the input is of the form + # [Batch , Sequence, Hidden_dim] + return self.forward_default( + x_q=x_q, + x_kv=x_kv, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + ) + + def profile_module(self, input) -> Tuple[Tensor, float, float]: + b_sz, seq_len, in_channels = input.shape + params = macs = 0.0 + + qkv, p, m = module_profile(module=self.qkv_proj, x=input) + params += p + macs += m * seq_len * b_sz + + # number of operations in QK^T + m_qk = (seq_len * seq_len * in_channels) * b_sz + macs += m_qk + + # number of operations in computing weighted sum + m_wt = (seq_len * seq_len * in_channels) * b_sz + macs += m_wt + + out_p, p, m = module_profile(module=self.out_proj, x=input) + params += p + macs += m * seq_len * b_sz + + return input, params, macs diff --git a/Adaptive Frequency Filters/affnet/layers/non_linear_layers.py b/Adaptive Frequency Filters/affnet/layers/non_linear_layers.py new file mode 100644 index 0000000..3b83df1 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/non_linear_layers.py @@ -0,0 +1,77 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- +import torch +from torch import nn +from typing import Optional + +from .activation import build_activation_layer + + +def get_complex_activation_fn( + act_type: Optional[str] = "relu", + num_parameters: Optional[int] = -1, + inplace: Optional[bool] = True, + negative_slope: Optional[float] = 0.1, + *args, + **kwargs +) -> nn.Module: + """ + Helper function to get activation (or non-linear) function + """ + class ComplexAct(nn.Module): + ''' + Naive approach to complex batch norm, perform batch norm independently on real and imaginary part. + ''' + def __init__(self, act_type, num_parameters, negative_slope, inplace, *args, **kwargs): + super(ComplexAct, self).__init__() + self.act_r = build_activation_layer( + act_type=act_type, + num_parameters=num_parameters, + negative_slope=negative_slope, + inplace=inplace, + *args, + **kwargs + ) + self.act_i = build_activation_layer( + act_type=act_type, + num_parameters=num_parameters, + negative_slope=negative_slope, + inplace=inplace, + *args, + **kwargs + ) + + def forward(self, input): + return self.act_r(input.real).type(torch.complex64) + 1j * self.act_i(input.imag).type(torch.complex64) + + return ComplexAct( + act_type=act_type, + num_parameters=num_parameters, + negative_slope=negative_slope, + inplace=inplace, + *args, + **kwargs + ) + +def get_activation_fn( + act_type: Optional[str] = "relu", + num_parameters: Optional[int] = -1, + inplace: Optional[bool] = True, + negative_slope: Optional[float] = 0.1, + *args, + **kwargs +) -> nn.Module: + """ + Helper function to get activation (or non-linear) function + """ + return build_activation_layer( + act_type=act_type, + num_parameters=num_parameters, + negative_slope=negative_slope, + inplace=inplace, + *args, + **kwargs + ) diff --git a/Adaptive Frequency Filters/affnet/layers/normalization/__init__.py b/Adaptive Frequency Filters/affnet/layers/normalization/__init__.py new file mode 100644 index 0000000..0e9afd3 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/normalization/__init__.py @@ -0,0 +1,143 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +import os +import importlib +import argparse +from typing import Optional + +from utils import logger + +from ..identity import Identity + +SUPPORTED_NORM_FNS = [] +NORM_LAYER_REGISTRY = {} +NORM_LAYER_CLS = [] + + +def register_norm_fn(name): + def register_fn(cls): + if name in SUPPORTED_NORM_FNS: + raise ValueError( + "Cannot register duplicate normalization function ({})".format(name) + ) + SUPPORTED_NORM_FNS.append(name) + NORM_LAYER_REGISTRY[name] = cls + NORM_LAYER_CLS.append(cls) + return cls + + return register_fn + + +def build_normalization_layer( + opts, + num_features: int, + norm_type: Optional[str] = None, + num_groups: Optional[int] = None, + *args, + **kwargs +) -> torch.nn.Module: + """ + Helper function to build the normalization layer. + The function can be used in either of below mentioned ways: + Scenario 1: Set the default normalization layers using command line arguments. This is useful when the same normalization + layer is used for the entire network (e.g., ResNet). + Scenario 2: Network uses different normalization layers. In that case, we can override the default normalization + layer by specifying the name using `norm_type` argument + """ + norm_type = ( + getattr(opts, "model.normalization.name", "batch_norm") + if norm_type is None + else norm_type + ) + num_groups = ( + getattr(opts, "model.normalization.groups", 1) + if num_groups is None + else num_groups + ) + momentum = getattr(opts, "model.normalization.momentum", 0.1) + norm_layer = None + norm_type = norm_type.lower() if norm_type is not None else None + + if norm_type in NORM_LAYER_REGISTRY: + if torch.cuda.device_count() < 1 and norm_type.find("sync_batch") > -1: + # for a CPU-device, Sync-batch norm does not work. So, change to batch norm + norm_type = norm_type.replace("sync_", "") + norm_layer = NORM_LAYER_REGISTRY[norm_type]( + normalized_shape=num_features, + num_features=num_features, + momentum=momentum, + num_groups=num_groups, + ) + elif norm_type == "identity": + norm_layer = Identity() + else: + logger.error( + "Supported normalization layer arguments are: {}. Got: {}".format( + SUPPORTED_NORM_FNS, norm_type + ) + ) + return norm_layer + + +def arguments_norm_layers(parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Normalization layers", description="Normalization layers" + ) + + group.add_argument( + "--model.normalization.name", + default=None, + type=str, + help="Normalization layer. Defaults to None", + ) + group.add_argument( + "--model.normalization.groups", + default=1, + type=str, + help="Number of groups in group normalization layer. Defaults to 1.", + ) + group.add_argument( + "--model.normalization.momentum", + default=0.1, + type=float, + help="Momentum in normalization layers. Defaults to 0.1", + ) + + # Adjust momentum in batch norm layers + group.add_argument( + "--model.normalization.adjust-bn-momentum.enable", + action="store_true", + help="Adjust momentum in batch normalization layers", + ) + group.add_argument( + "--model.normalization.adjust-bn-momentum.anneal-type", + default="cosine", + type=str, + help="Method for annealing momentum in Batch normalization layer", + ) + group.add_argument( + "--model.normalization.adjust-bn-momentum.final-momentum-value", + default=1e-6, + type=float, + help="Min. momentum in batch normalization layer", + ) + + return parser + + +# automatically import different normalization layers +norm_dir = os.path.dirname(__file__) +for file in os.listdir(norm_dir): + path = os.path.join(norm_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("affnet.layers.normalization." + model_name) diff --git a/Adaptive Frequency Filters/affnet/layers/normalization/batch_norm.py b/Adaptive Frequency Filters/affnet/layers/normalization/batch_norm.py new file mode 100644 index 0000000..a8e88b3 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/normalization/batch_norm.py @@ -0,0 +1,173 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Optional, Tuple +import torch + +from . import register_norm_fn + + +@register_norm_fn(name="batch_norm") +@register_norm_fn(name="batch_norm_2d") +class BatchNorm2d(nn.BatchNorm2d): + """ + Applies a `Batch Normalization `_ over a 4D input tensor + + Args: + num_features (Optional, int): :math:`C` from an expected input of size :math:`(N, C, H, W)` + eps (Optional, float): Value added to the denominator for numerical stability. Default: 1e-5 + momentum (Optional, float): Value used for the running_mean and running_var computation. Default: 0.1 + affine (bool): If ``True``, use learnable affine parameters. Default: ``True`` + track_running_stats: If ``True``, tracks running mean and variance. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` where :math:`N` is the batch size, :math:`C` is the number of input channels, + :math:`H` is the input height, and :math:`W` is the input width + - Output: same shape as the input + """ + + def __init__( + self, + num_features: int, + eps: Optional[float] = 1e-5, + momentum: Optional[float] = 0.1, + affine: Optional[bool] = True, + track_running_stats: Optional[bool] = True, + *args, + **kwargs + ) -> None: + super().__init__( + num_features=num_features, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + ) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + # Since normalization layers can be fused, we do not count their operations + params = sum([p.numel() for p in self.parameters()]) + return input, params, 0.0 + + +@register_norm_fn(name="batch_norm_fp32") +class BatchNorm2dFP32(BatchNorm2d): + """ + Applies a `Batch Normalization `_ over a 4D input tensor in FP32 + """ + + def __init__( + self, + num_features: int, + eps: Optional[float] = 1e-5, + momentum: Optional[float] = 0.1, + affine: Optional[bool] = True, + track_running_stats: Optional[bool] = True, + *args, + **kwargs + ) -> None: + super().__init__( + num_features=num_features, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + *args, + **kwargs + ) + + def forward(self, input: Tensor) -> Tensor: + inp_dtype = input.dtype + return super().forward(input.to(torch.float32)).to(inp_dtype) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + # Since normalization layers can be fused, we do not count their operations + params = sum([p.numel() for p in self.parameters()]) + return input, params, 0.0 + + +@register_norm_fn(name="batch_norm_1d") +class BatchNorm1d(nn.BatchNorm1d): + """ + Applies a `Batch Normalization `_ over a 2D or 3D input tensor + + Args: + num_features (Optional, int): :math:`C` from an expected input of size :math:`(N, C)` or :math:`(N, C, L)` + eps (Optional, float): Value added to the denominator for numerical stability. Default: 1e-5 + momentum (Optional, float): Value used for the running_mean and running_var computation. Default: 0.1 + affine (bool): If ``True``, use learnable affine parameters. Default: ``True`` + track_running_stats: If ``True``, tracks running mean and variance. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` where :math:`N` is the batch size, + :math:`C` is the number of input channels, and :math:`L` is the sequence length + - Output: same shape as the input + """ + + def __init__( + self, + num_features: int, + eps: Optional[float] = 1e-5, + momentum: Optional[float] = 0.1, + affine: Optional[bool] = True, + track_running_stats: Optional[bool] = True, + *args, + **kwargs + ) -> None: + super().__init__( + num_features=num_features, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + ) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + # Since normalization layers can be fused, we do not count their operations + params = sum([p.numel() for p in self.parameters()]) + return input, params, 0.0 + + +@register_norm_fn(name="batch_norm_3d") +class BatchNorm3d(nn.BatchNorm3d): + def __init__( + self, + num_features: int, + eps: Optional[float] = 1e-5, + momentum: Optional[float] = 0.1, + affine: Optional[bool] = True, + track_running_stats: Optional[bool] = True, + *args, + **kwargs + ) -> None: + """ + Applies a `Batch Normalization `_ over a 5D input tensor + + Args: + num_features (Optional, int): :math:`C` from an expected input of size :math:`(N, C, D, H, W)` + eps (Optional, float): Value added to the denominator for numerical stability. Default: 1e-5 + momentum (Optional, float): Value used for the running_mean and running_var computation. Default: 0.1 + affine (bool): If ``True``, use learnable affine parameters. Default: ``True`` + track_running_stats: If ``True``, tracks running mean and variance. Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` where :math:`N` is the batch size, :math:`C` is the number of input + channels, :math:`D` is the input depth, :math:`H` is the input height, and :math:`W` is the input width + - Output: same shape as the input + """ + super().__init__( + num_features=num_features, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + ) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + # Since normalization layers can be fused, we do not count their operations + params = sum([p.numel() for p in self.parameters()]) + return input, params, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/normalization/group_norm.py b/Adaptive Frequency Filters/affnet/layers/normalization/group_norm.py new file mode 100644 index 0000000..7f41eaf --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/normalization/group_norm.py @@ -0,0 +1,50 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Optional, Tuple + +from . import register_norm_fn + + +@register_norm_fn(name="group_norm") +class GroupNorm(nn.GroupNorm): + """ + Applies a `Group Normalization `_ over an input tensor + + Args: + num_groups (int): number of groups to separate the input channels into + num_features (int): :math:`C` from an expected input of size :math:`(N, C, *)` + eps (Optional, float): Value added to the denominator for numerical stability. Default: 1e-5 + affine (bool): If ``True``, use learnable affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, *)` where :math:`N` is the batch size, :math:`C` is the number of input channels, + and :math:`*` is the remaining dimensions of the input tensor + - Output: same shape as the input + + .. note:: + GroupNorm is the same as LayerNorm when `num_groups=1` and it is the same as InstanceNorm when + `num_groups=C`. + """ + + def __init__( + self, + num_groups: int, + num_features: int, + eps: Optional[float] = 1e-5, + affine: Optional[bool] = True, + *args, + **kwargs + ) -> None: + super().__init__( + num_groups=num_groups, num_channels=num_features, eps=eps, affine=affine + ) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + # Since normalization layers can be fused, we do not count their operations + params = sum([p.numel() for p in self.parameters()]) + return input, params, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/normalization/instance_norm.py b/Adaptive Frequency Filters/affnet/layers/normalization/instance_norm.py new file mode 100644 index 0000000..066c3bd --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/normalization/instance_norm.py @@ -0,0 +1,95 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Optional, Tuple + +from . import register_norm_fn + + +@register_norm_fn(name="instance_norm") +@register_norm_fn(name="instance_norm_2d") +class InstanceNorm2d(nn.InstanceNorm2d): + """ + Applies a `Instance Normalization `_ over a 4D input tensor + + Args: + num_features (int): :math:`C` from an expected input of size :math:`(N, C, H, W)` + eps (Optional, float): Value added to the denominator for numerical stability. Default: 1e-5 + momentum (Optional, float): Value used for the running_mean and running_var computation. Default: 0.1 + affine (bool): If ``True``, use learnable affine parameters. Default: ``True`` + track_running_stats: If ``True``, tracks running mean and variance. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` where :math:`N` is the batch size, :math:`C` is the number of input channels, + :math:`H` is the input height, and :math:`W` is the input width + - Output: same shape as the input + """ + + def __init__( + self, + num_features: int, + eps: Optional[float] = 1e-5, + momentum: Optional[float] = 0.1, + affine: Optional[bool] = True, + track_running_stats: Optional[bool] = True, + *args, + **kwargs + ) -> None: + super().__init__( + num_features=num_features, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + ) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + # Since normalization layers can be fused, we do not count their operations + params = sum([p.numel() for p in self.parameters()]) + return input, params, 0.0 + + +@register_norm_fn(name="instance_norm_1d") +class InstanceNorm1d(nn.InstanceNorm1d): + """ + Applies a `Instance Normalization `_ over a 2D or 3D input tensor + + Args: + num_features (int): :math:`C` from an expected input of size :math:`(N, C)` or :math:`(N, C, L)` + eps (Optional, float): Value added to the denominator for numerical stability. Default: 1e-5 + momentum (Optional, float): Value used for the running_mean and running_var computation. Default: 0.1 + affine (bool): If ``True``, use learnable affine parameters. Default: ``True`` + track_running_stats: If ``True``, tracks running mean and variance. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` where :math:`N` is the batch size, :math:`C` is the number + of input channels, and :math:`L` is the sequence length + - Output: same shape as the input + """ + + def __init__( + self, + num_features: int, + eps: Optional[float] = 1e-5, + momentum: Optional[float] = 0.1, + affine: Optional[bool] = True, + track_running_stats: Optional[bool] = True, + *args, + **kwargs + ) -> None: + super().__init__( + num_features=num_features, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + ) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + # Since normalization layers can be fused, we do not count their operations + params = sum([p.numel() for p in self.parameters()]) + return input, params, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/normalization/layer_norm.py b/Adaptive Frequency Filters/affnet/layers/normalization/layer_norm.py new file mode 100644 index 0000000..01927d2 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/normalization/layer_norm.py @@ -0,0 +1,145 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor, Size +from typing import Optional, Union, List +import torch + +from . import register_norm_fn + + +@register_norm_fn(name="layer_norm") +class LayerNorm(nn.LayerNorm): + """ + Applies `Layer Normalization `_ over a input tensor + + Args: + normalized_shape (int or list or torch.Size): input shape from an expected input + of size + + .. math:: + [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] + \times \ldots \times \text{normalized\_shape}[-1]] + + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (Optional, float): Value added to the denominator for numerical stability. Default: 1e-5 + elementwise_affine (bool): If ``True``, use learnable affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, *)` where :math:`N` is the batch size + - Output: same shape as the input + """ + + def __init__( + self, + normalized_shape: Union[int, List[int], Size], + eps: Optional[float] = 1e-5, + elementwise_affine: Optional[bool] = True, + *args, + **kwargs + ): + super().__init__( + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + ) + + def forward(self, x: Tensor) -> Tensor: + n_dim = x.ndim + if x.shape[1] == self.normalized_shape[0] and n_dim > 2: # channel-first format + s, u = torch.std_mean(x, dim=1, keepdim=True, unbiased=False) + x = (x - u) / (s + self.eps) + if self.weight is not None: + # Using fused operation for performing affine transformation: x = (x * weight) + bias + n_dim = x.ndim - 2 + new_shape = [1, self.normalized_shape[0]] + [1] * n_dim + x = torch.addcmul( + input=self.bias.reshape(*[new_shape]), + value=1.0, + tensor1=x, + tensor2=self.weight.reshape(*[new_shape]), + ) + return x + elif x.shape[-1] == self.normalized_shape[0]: # channel-last format + return super().forward(x) + else: + raise NotImplementedError( + "LayerNorm is supported for channel-first and channel-last format only" + ) + + def profile_module(self, input: Tensor) -> (Tensor, float, float): + params = sum([p.numel() for p in self.parameters()]) + return input, params, 0.0 + + +@register_norm_fn(name="layer_norm_2d") +@register_norm_fn(name="layer_norm_nchw") +class LayerNorm2D_NCHW(nn.GroupNorm): + """ + Applies `Layer Normalization `_ over a 4D input tensor + + Args: + num_features (int): :math:`C` from an expected input of size :math:`(N, C, H, W)` + eps (Optional, float): Value added to the denominator for numerical stability. Default: 1e-5 + elementwise_affine (bool): If ``True``, use learnable affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` where :math:`N` is the batch size, :math:`C` is the number of input channels, + :math:`H` is the input height, and :math:`W` is the input width + - Output: same shape as the input + """ + + def __init__( + self, + num_features: int, + eps: Optional[float] = 1e-5, + elementwise_affine: Optional[bool] = True, + *args, + **kwargs + ) -> None: + super().__init__( + num_channels=num_features, eps=eps, affine=elementwise_affine, num_groups=1 + ) + self.num_channels = num_features + + def __repr__(self): + return "{}(num_channels={}, eps={}, affine={})".format( + self.__class__.__name__, self.num_channels, self.eps, self.affine + ) + + def profile_module(self, input: Tensor) -> (Tensor, float, float): + params = sum([p.numel() for p in self.parameters()]) + return input, params, 0.0 + + +@register_norm_fn(name="layer_norm_fp32") +class LayerNormFP32(LayerNorm): + """ + Applies `Layer Normalization `_ over a input tensor with FP32 precision + """ + + def __init__( + self, + normalized_shape: Union[int, List[int], Size], + eps: Optional[float] = 1e-5, + elementwise_affine: Optional[bool] = True, + *args, + **kwargs + ): + super().__init__( + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + *args, + **kwargs + ) + + def forward(self, x: Tensor) -> Tensor: + # Convert input from dtype X to FP32 and perform normalization operation. + # This may help with underflow/overflow issues that we typically see with normalization layers + inp_dtype = x.dtype + return super().forward(x.to(torch.float32)).to(inp_dtype) diff --git a/Adaptive Frequency Filters/affnet/layers/normalization/sync_batch_norm.py b/Adaptive Frequency Filters/affnet/layers/normalization/sync_batch_norm.py new file mode 100644 index 0000000..72acdcc --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/normalization/sync_batch_norm.py @@ -0,0 +1,88 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from typing import Optional, Tuple + +from . import register_norm_fn + + +@register_norm_fn(name="sync_batch_norm") +class SyncBatchNorm(nn.SyncBatchNorm): + """ + Applies a `Syncronized Batch Normalization `_ over the input tensor + + Args: + num_features (Optional, int): :math:`C` from an expected input of size :math:`(N, C, *)` + eps (Optional, float): Value added to the denominator for numerical stability. Default: 1e-5 + momentum (Optional, float): Value used for the running_mean and running_var computation. Default: 0.1 + affine (bool): If ``True``, use learnable affine parameters. Default: ``True`` + track_running_stats: If ``True``, tracks running mean and variance. Default: ``True`` + + Shape: + - Input: :math:`(N, C, *)` where :math:`N` is the batch size, :math:`C` is the number of input channels, + :math:`*` is the remaining input dimensions + - Output: same shape as the input + + """ + + def __init__( + self, + num_features: int, + eps: Optional[float] = 1e-5, + momentum: Optional[float] = 0.1, + affine: Optional[bool] = True, + track_running_stats: Optional[bool] = True, + *args, + **kwargs + ) -> None: + super().__init__( + num_features=num_features, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + ) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + # Since normalization layers can be fused, we do not count their operations + params = sum([p.numel() for p in self.parameters()]) + return input, params, 0.0 + + +@register_norm_fn(name="sync_batch_norm_fp32") +class SyncBatchNormFP32(SyncBatchNorm): + """ + Synchronized BN in FP32 + """ + + def __init__( + self, + num_features: int, + eps: Optional[float] = 1e-5, + momentum: Optional[float] = 0.1, + affine: Optional[bool] = True, + track_running_stats: Optional[bool] = True, + *args, + **kwargs + ) -> None: + super().__init__( + num_features=num_features, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + ) + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + in_dtype = x.dtype + return super().forward(x.to(dtype=torch.float)).to(dtype=in_dtype) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + # Since normalization layers can be fused, we do not count their operations + params = sum([p.numel() for p in self.parameters()]) + return input, params, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/normalization_layers.py b/Adaptive Frequency Filters/affnet/layers/normalization_layers.py new file mode 100644 index 0000000..6a80e32 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/normalization_layers.py @@ -0,0 +1,148 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn +from typing import Optional +from utils import logger +import math + +from .normalization import build_normalization_layer, NORM_LAYER_CLS + +norm_layers_tuple = tuple(NORM_LAYER_CLS) + +def get_complex_normalization_layer( + opts, + num_features: int, + norm_type: Optional[str] = None, + num_groups: Optional[int] = None, + *args, + **kwargs +) -> nn.Module: + """ + Helper function to get normalization layers + """ + class ComplexNorm(nn.Module): + ''' + Naive approach to complex batch norm, perform batch norm independently on real and imaginary part. + ''' + def __init__(self, opts, num_features, norm_type, num_groups): + super(ComplexNorm, self).__init__() + self.bn_r = build_normalization_layer(opts, num_features, norm_type, num_groups) + self.bn_i = build_normalization_layer(opts, num_features, norm_type, num_groups) + + def forward(self, input): + return self.bn_r(input.real).type(torch.complex64) + 1j * self.bn_i(input.imag).type(torch.complex64) + + # return build_normalization_layer(opts, num_features, norm_type, num_groups) + return ComplexNorm(opts, num_features, norm_type, num_groups) + +def get_normalization_layer( + opts, + num_features: int, + norm_type: Optional[str] = None, + num_groups: Optional[int] = None, + *args, + **kwargs +) -> nn.Module: + """ + Helper function to get normalization layers + """ + return build_normalization_layer(opts, num_features, norm_type, num_groups) + + +class AdjustBatchNormMomentum(object): + """ + This class enables adjusting the momentum in batch normalization layer. + + .. note:: + It's an experimental feature and should be used with caution. + """ + + round_places = 6 + + def __init__(self, opts, *args, **kwargs): + self.is_iteration_based = getattr(opts, "scheduler.is_iteration_based", True) + self.warmup_iterations = getattr(opts, "scheduler.warmup_iterations", 10000) + + if self.is_iteration_based: + self.max_steps = getattr(opts, "scheduler.max_iterations", 100000) + self.max_steps -= self.warmup_iterations + assert self.max_steps > 0 + else: + logger.warning( + "Running {} for epoch-based methods. Not yet validation.".format( + self.__class__.__name__ + ) + ) + self.max_steps = getattr(opts, "scheduler.max_epochs", 100) + + self.momentum = getattr(opts, "model.normalization.momentum", 0.1) + self.min_momentum = getattr( + opts, "model.normalization.adjust_bn_momentum.final_momentum_value", 1e-6 + ) + + if self.min_momentum >= self.momentum: + logger.error( + "Min. momentum value in {} should be <= momentum. Got {} and {}".format( + self.__class__.__name__, self.min_momentum, self.momentum + ) + ) + + anneal_method = getattr( + opts, "model.normalization.adjust_bn_momentum.anneal_type", "cosine" + ) + if anneal_method is None: + logger.warning( + "Annealing method in {} is None. Setting to cosine".format( + self.__class__.__name__ + ) + ) + anneal_method = "cosine" + + anneal_method = anneal_method.lower() + + if anneal_method == "cosine": + self.anneal_fn = self._cosine + elif anneal_method == "linear": + self.anneal_fn = self._linear + else: + raise RuntimeError( + "Anneal method ({}) not yet implemented".format(anneal_method) + ) + self.anneal_method = anneal_method + + def _cosine(self, step: int) -> float: + curr_momentum = self.min_momentum + 0.5 * ( + self.momentum - self.min_momentum + ) * (1 + math.cos(math.pi * step / self.max_steps)) + + return round(curr_momentum, self.round_places) + + def _linear(self, step: int) -> float: + momentum_step = (self.momentum - self.min_momentum) / self.max_steps + curr_momentum = self.momentum - (step * momentum_step) + return round(curr_momentum, self.round_places) + + def adjust_momentum(self, model: nn.Module, iteration: int, epoch: int) -> None: + if iteration >= self.warmup_iterations: + step = ( + iteration - self.warmup_iterations if self.is_iteration_based else epoch + ) + curr_momentum = max(0.0, self.anneal_fn(step)) + + for m in model.modules(): + if isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)) and m.training: + m.momentum = curr_momentum + + def __repr__(self): + return "{}(iteration_based={}, inital_momentum={}, final_momentum={}, anneal_method={})".format( + self.__class__.__name__, + self.is_iteration_based, + self.momentum, + self.min_momentum, + self.anneal_method, + ) diff --git a/Adaptive Frequency Filters/affnet/layers/pixel_shuffle.py b/Adaptive Frequency Filters/affnet/layers/pixel_shuffle.py new file mode 100644 index 0000000..8f63037 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/pixel_shuffle.py @@ -0,0 +1,34 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Tuple + + +class PixelShuffle(nn.PixelShuffle): + """ + Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` + to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor. + + Args: + upscale_factor (int): factor to increase spatial resolution by + + Shape: + - Input: :math:`(*, C \times r^2, H, W)`, where * is zero or more dimensions + - Output: :math:`(*, C, H \times r, W \times r)` + """ + + def __init__(self, upscale_factor: int, *args, **kwargs) -> None: + super(PixelShuffle, self).__init__(upscale_factor=upscale_factor) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + input = self.forward(input) + return input, 0.0, 0.0 + + def __repr__(self): + return "{}(upscale_factor={})".format( + self.__class__.__name__, self.upscale_factor + ) diff --git a/Adaptive Frequency Filters/affnet/layers/pooling.py b/Adaptive Frequency Filters/affnet/layers/pooling.py new file mode 100644 index 0000000..eb54764 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/pooling.py @@ -0,0 +1,91 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Optional + + +class MaxPool2d(nn.MaxPool2d): + """ + Applies a 2D max pooling over a 4D input tensor. + + Args: + kernel_size (Optional[int]): the size of the window to take a max over + stride (Optional[int]): The stride of the window. Default: 2 + padding (Optional[int]): Padding to be added on both sides of the tensor. Default: 1 + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` where :math:`N` is the batch size, :math:`C` is the input channels, + :math:`H_{in}` is the input height, and :math:`W_{in}` is the input width + - Output: :math:`(N, C, H_{out}, W_{out})` where :math:`H_{out}` is the output height, and :math:`W_{in}` is + the output width + """ + + def __init__( + self, + kernel_size: Optional[int] = 3, + stride: Optional[int] = 2, + padding: Optional[int] = 1, + *args, + **kwargs + ) -> None: + super().__init__(kernel_size=kernel_size, stride=stride, padding=padding) + + def profile_module(self, input: Tensor) -> (Tensor, float, float): + input = self.forward(input) + return input, 0.0, 0.0 + + def __repr__(self): + return "{}(kernel_size={}, stride={})".format( + self.__class__.__name__, self.kernel_size, self.stride + ) + + +class AvgPool2d(nn.AvgPool2d): + """ + Applies a 2D average pooling over a 4D input tensor. + + Args: + kernel_size (Optional[int]): the size of the window to take a max over + stride (Optional[int]): The stride of the window. Default: 2 + padding (Optional[int]): Padding to be added on both sides of the tensor. Default: 1 + ceil_mode (Optional[bool]): When True, will use `ceil` instead of `floor` to compute the output shape. Default: False + count_include_pad (Optional[bool]): When True, will include the zero-padding in the averaging calculation. Default: True + divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. Default: None + + Shape: + - Input: :math:`(N, C, H_{in}, W_{in})` where :math:`N` is the batch size, :math:`C` is the input channels, + :math:`H_{in}` is the input height, and :math:`W_{in}` is the input width + - Output: :math:`(N, C, H_{out}, W_{out})` where :math:`H_{out}` is the output height, and :math:`W_{in}` is + the output width + """ + + def __init__( + self, + kernel_size: tuple, + stride: Optional[tuple] = None, + padding: Optional[tuple] = (0, 0), + ceil_mode: Optional[bool] = False, + count_include_pad: Optional[bool] = True, + divisor_override: Optional[bool] = None, + ): + super(AvgPool2d, self).__init__( + kernel_size=kernel_size, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + divisor_override=divisor_override, + ) + + def profile_module(self, input: Tensor) -> (Tensor, float, float): + input = self.forward(input) + return input, 0.0, 0.0 + + def __repr__(self): + return "{}(upscale_factor={})".format( + self.__class__.__name__, self.upscale_factor + ) diff --git a/Adaptive Frequency Filters/affnet/layers/positional_embedding.py b/Adaptive Frequency Filters/affnet/layers/positional_embedding.py new file mode 100644 index 0000000..67b8d8e --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/positional_embedding.py @@ -0,0 +1,189 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from torch.nn import functional as F +from typing import Optional +import math + +from . import BaseLayer + + +class PositionalEmbedding(BaseLayer): + def __init__( + self, + opts, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + is_learnable: Optional[bool] = False, + sequence_first: Optional[bool] = False, + interpolation_mode: Optional[str] = "bilinear", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + module = ( + LearnablePositionalEmbedding + if is_learnable + else SinusoidalPositionalEmbedding + ) + self.pos_embed = module( + opts, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + sequence_first=sequence_first, + interpolation_mode=interpolation_mode, + *args, + **kwargs + ) + + def forward(self, seq_len: int, *args, **kwargs) -> Tensor: + return self.pos_embed(seq_len, *args, **kwargs) + + def profile_module(self, input: Tensor, *args, **kwargs) -> (Tensor, float, float): + return input, 0.0, 0.0 + + def __repr__(self): + return self.pos_embed.__repr__() + + +class LearnablePositionalEmbedding(nn.Module): + """Learnable Positional embedding""" + + def __init__( + self, + opts, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + sequence_first: Optional[bool] = False, + interpolation_mode: Optional[str] = "bilinear", + *args, + **kwargs + ): + super().__init__() + self.pos_embed = nn.Parameter(torch.empty(1, 1, num_embeddings, embedding_dim)) + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + self.padding_idx = padding_idx + self.sequence_first = sequence_first + self.interpolation_mode = interpolation_mode + + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.trunc_normal_(self.pos_embed, mean=0, std=self.embedding_dim**-0.5) + if self.padding_idx is not None: + with torch.no_grad(): + self.pos_embed[:, :, self.padding_idx, ...] = 0.0 + + def profile_module(self, input: Tensor, *args, **kwargs) -> (Tensor, float, float): + return input, 0.0, 0.0 + + def forward(self, seq_len: int, *args, **kwargs) -> Tensor: + # scale pos embedding + pos_embed = self.pos_embed + if self.padding_idx is not None: + with torch.no_grad(): + pos_embed[:, :, self.padding_idx, ...] = 0.0 + + if seq_len != self.num_embeddings: + pos_embed = F.interpolate( + pos_embed, + size=(seq_len, self.embedding_dim), + mode=self.interpolation_mode, + ) + + # add dummy batch dimension + if self.sequence_first: + # Input is of the form [Seq_len, Batch, Embedding_dim] + return pos_embed.reshape(seq_len, 1, self.embedding_dim) + else: + # Input is of the form [Batch, Seq_len, Embedding_dim] + return pos_embed.reshape(1, seq_len, self.embedding_dim) + + def __repr__(self): + return "{}(num_embeddings={}, embedding_dim={}, padding_idx={}, sequence_first={})".format( + self.__class__.__name__, + self.num_embeddings, + self.embedding_dim, + self.padding_idx, + self.sequence_first, + ) + + +class SinusoidalPositionalEmbedding(nn.Module): + def __init__( + self, + opts, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + sequence_first: Optional[bool] = False, + interpolation_mode: Optional[str] = "bilinear", + *args, + **kwargs + ): + super().__init__() + self.padding_idx = padding_idx + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.sequence_first = sequence_first + self.interpolation_mode = interpolation_mode + self.register_buffer("pos_embed", self.get_weights()) + + def get_weights(self) -> Tensor: + """Build sinusoidal embeddings. Adapted from Fairseq.""" + half_dim = self.embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(self.num_embeddings, dtype=torch.float).unsqueeze( + 1 + ) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).reshape( + self.num_embeddings, -1 + ) + if self.embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(self.num_embeddings, 1)], dim=1) + + # set embeddings corresponding to padding index to 0 + if self.padding_idx is not None: + emb[self.padding_idx, :] = 0 + return emb.unsqueeze(0).unsqueeze(0) + + def forward(self, seq_len: int, *args, **kwargs) -> Tensor: + # scale pos embedding + pos_embed = self.pos_embed + + if seq_len != self.num_embeddings: + pos_embed = F.interpolate( + pos_embed, + size=(seq_len, self.embedding_dim), + mode=self.interpolation_mode, + ) + + if self.sequence_first: + # Input is of the form [Seq_len, Batch, Embedding_dim] + return pos_embed.reshape(seq_len, 1, self.embedding_dim) + else: + # Input is of the form [Batch, Seq_len, Embedding_dim] + return pos_embed.reshape(1, seq_len, self.embedding_dim) + + def profile_module(self, input: Tensor, *args, **kwargs) -> (Tensor, float, float): + return input, 0.0, 0.0 + + def __repr__(self): + return "{}(num_embeddings={}, embedding_dim={}, padding_idx={}, sequence_first={})".format( + self.__class__.__name__, + self.num_embeddings, + self.embedding_dim, + self.padding_idx, + self.sequence_first, + ) diff --git a/Adaptive Frequency Filters/affnet/layers/positional_encoding.py b/Adaptive Frequency Filters/affnet/layers/positional_encoding.py new file mode 100644 index 0000000..9ccf3f3 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/positional_encoding.py @@ -0,0 +1,162 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +import math +from typing import Optional, Tuple + +from .base_layer import BaseLayer +from .dropout import Dropout + + +class SinusoidalPositionalEncoding(BaseLayer): + """ + This layer adds sinusoidal positional embeddings to a 3D input tensor. The code has been adapted from + `Pytorch tutorial `_ + + Args: + d_model (int): dimension of the input tensor + dropout (Optional[float]): Dropout rate. Default: 0.0 + max_len (Optional[int]): Max. number of patches (or seq. length). Default: 5000 + channels_last (Optional[bool]): Channels dimension is the last in the input tensor + + Shape: + - Input: :math:`(N, C, P)` or :math:`(N, P, C)` where :math:`N` is the batch size, :math:`C` is the embedding dimension, + :math:`P` is the number of patches + - Output: same shape as the input + + """ + + def __init__( + self, + d_model: int, + dropout: Optional[float] = 0.0, + max_len: Optional[int] = 5000, + channels_last: Optional[bool] = True, + *args, + **kwargs + ) -> None: + + position_last = not channels_last + + pos_encoding = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) + ) + pos_encoding[:, 0::2] = torch.sin(position * div_term) + pos_encoding[:, 1::2] = torch.cos(position * div_term) + # add dummy batch dimension + pos_encoding = pos_encoding.unsqueeze(0) # [1 x C x P_max) + + patch_dim = -2 # patch dimension is second last (N, P, C) + if position_last: + pos_encoding = pos_encoding.transpose( + 1, 2 + ) # patch dimension is last (N, C, P) + patch_dim = -1 # patch dimension is last (N, C, P) + + super().__init__() + + self.dropout = Dropout(p=dropout) + self.patch_dim = patch_dim + self.register_buffer("pe", pos_encoding) + + def forward_patch_last( + self, x, indices: Optional[Tensor] = None, *args, **kwargs + ) -> Tensor: + # seq_length should be the last dim + if indices is None: + x = x + self.pe[..., : x.shape[-1]] + else: + ndim = x.ndim + repeat_size = [x.shape[0]] + [-1] * (ndim - 1) + + pe = self.pe.expand(repeat_size) + selected_pe = torch.gather(pe, index=indices, dim=-1) + x = x + selected_pe + return self.dropout(x) + + def forward_others( + self, x, indices: Optional[Tensor] = None, *args, **kwargs + ) -> Tensor: + # seq_length should be the second last dim + if indices is None: + x = x + self.pe[..., : x.shape[-2], :] + else: + ndim = x.ndim + repeat_size = [x.shape[0]] + [-1] * (ndim - 1) + + pe = self.pe.expand(repeat_size) + selected_pe = torch.gather(pe, index=indices, dim=-2) + x = x + selected_pe + return self.dropout(x) + + def forward(self, x, indices: Optional[Tensor] = None, *args, **kwargs) -> Tensor: + if self.patch_dim == -1: + return self.forward_patch_last(x, indices=indices) + else: + return self.forward_others(x, indices=indices) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + return input, 0.0, 0.0 + + def __repr__(self): + return "{}(dropout={})".format(self.__class__.__name__, self.dropout.p) + + +class LearnablePositionEncoding(BaseLayer): + """ + This layer adds learnable positional embeddings to a 3D input tensor. + + Args: + embed_dim (int): dimension of the input tensor + num_embeddings (int): number of input embeddings. This is similar to vocab size in NLP. + dropout (Optional[float]): Dropout rate. Default: 0.0 + channels_last (Optional[bool]): Channels dimension is the last in the input tensor + + Shape: + - Input: :math:`(N, *, C, P)` or :math:`(N, *, P, C)` where :math:`N` is the batch size, :math:`C` is the embedding dimension, + :math:`P` is the number of patches + - Output: same shape as the input + + """ + + def __init__( + self, + embed_dim: int, + num_embeddings: int, + dropout: Optional[float] = 0.0, + channels_last: Optional[bool] = True, + *args, + **kwargs + ) -> None: + super().__init__() + self.pos_emb = nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=embed_dim + ) + self.channel_last = channels_last + self.dropout = Dropout(p=dropout) + + def forward(self, x, *args, **kwargs) -> Tensor: + num_embeddings = x.shape[-2] if self.channel_last else x.shape[-1] + posistions = torch.arange(num_embeddings, dtype=torch.int64, device=x.device) + position_emb = self.pos_emb(posistions) + position_emb = position_emb.expand_as(x) + x = x + position_emb + return self.dropout(x) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + return input, 0.0, 0.0 + + def __repr__(self): + return "{}(embed_dim={}, vocab_size={}, dropout={})".format( + self.__class__.__name__, + self.pos_emb.embedding_dim, + self.pos_emb.num_embeddings, + self.dropout.p, + ) diff --git a/Adaptive Frequency Filters/affnet/layers/random_layers.py b/Adaptive Frequency Filters/affnet/layers/random_layers.py new file mode 100644 index 0000000..5bfba92 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/random_layers.py @@ -0,0 +1,61 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import Tensor +import random +from typing import List, Optional, Tuple + +from utils.math_utils import bound_fn + +from .base_layer import BaseLayer + + +class RandomApply(BaseLayer): + """ + This layer randomly applies a list of modules during training. + + Args: + module_list (List): List of modules + keep_p (Optional[float]): Keep P modules from the list during training. Default: 0.8 (or 80%) + """ + + def __init__( + self, module_list: List, keep_p: Optional[float] = 0.8, *args, **kwargs + ) -> None: + super().__init__() + n_modules = len(module_list) + self.module_list = module_list + + self.module_indexes = [i for i in range(1, n_modules)] + k = int(round(n_modules * keep_p)) + self.keep_k = bound_fn(min_val=1, max_val=n_modules, value=k) + + def forward(self, x: Tensor) -> Tensor: + if self.training: + indexes = [0] + sorted(random.sample(self.module_indexes, k=self.keep_k)) + for idx in indexes: + x = self.module_list[idx](x) + else: + for layer in self.module_list: + x = layer(x) + return x + + def profile_module(self, x, *args, **kwargs) -> Tuple[Tensor, float, float]: + params, macs = 0.0, 0.0 + for layer in self.module_list: + x, p, m = layer.profile_module(x) + params += p + macs += m + return x, params, macs + + def __repr__(self): + format_string = "{}(apply_k (N={})={}, ".format( + self.__class__.__name__, len(self.module_list), self.keep_k + ) + for layer in self.module_list: + format_string += "\n\t {}".format(layer) + format_string += "\n)" + return format_string diff --git a/Adaptive Frequency Filters/affnet/layers/single_head_attention.py b/Adaptive Frequency Filters/affnet/layers/single_head_attention.py new file mode 100644 index 0000000..2933a5f --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/single_head_attention.py @@ -0,0 +1,154 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from typing import Tuple, Optional +from torch.nn import functional as F + +from .base_layer import BaseLayer +from .linear_layer import LinearLayer +from .dropout import Dropout +from ..misc.profiler import module_profile + + +class SingleHeadAttention(BaseLayer): + """ + This layer applies a single-head attention as described in `DeLighT `_ paper + + Args: + embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})` + attn_dropout (Optional[float]): Attention dropout. Default: 0.0 + bias (Optional[bool]): Use bias or not. Default: ``True`` + + Shape: + - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches, + and :math:`C_{in}` is input embedding dim + - Output: same shape as the input + + """ + + def __init__( + self, + embed_dim: int, + attn_dropout: Optional[float] = 0.0, + bias: Optional[bool] = True, + *args, + **kwargs + ) -> None: + super().__init__() + + self.qkv_proj = LinearLayer( + in_features=embed_dim, out_features=3 * embed_dim, bias=bias + ) + + self.attn_dropout = Dropout(p=attn_dropout) + self.out_proj = LinearLayer( + in_features=embed_dim, out_features=embed_dim, bias=bias + ) + + self.softmax = nn.Softmax(dim=-1) + self.embed_dim = embed_dim + self.scaling = self.embed_dim**-0.5 + + def __repr__(self) -> str: + return "{}(embed_dim={}, attn_dropout={})".format( + self.__class__.__name__, self.embed_dim, self.attn_dropout.p + ) + + def forward( + self, + x: Tensor, + x_kv: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + *args, + **kwargs + ) -> Tensor: + # [N, P, C] --> [N, P, 3C] + if x_kv is None: + qkv = self.qkv_proj(x) + # [N, P, 3C] --> [N, P, C] x 3 + query, key, value = torch.chunk(qkv, chunks=3, dim=-1) + else: + query = F.linear( + x, + weight=self.qkv_proj.weight[: self.embed_dim, ...], + bias=self.qkv_proj.bias[: self.embed_dim], + ) + + # [N, P, C] --> [N, P, 2C] + kv = F.linear( + x_kv, + weight=self.qkv_proj.weight[self.embed_dim :, ...], + bias=self.qkv_proj.bias[self.embed_dim :], + ) + key, value = torch.chunk(kv, chunks=2, dim=-1) + + query = query * self.scaling + + # [N, P, C] --> [N, C, P] + key = key.transpose(-2, -1) + + # QK^T + # [N, P, C] x [N, C, P] --> [N, P, P] + attn = torch.matmul(query, key) + + if attn_mask is not None: + # attn_mask shape should be the same as attn + assert list(attn_mask.shape) == list( + attn.shape + ), "Shape of attention mask and attn should be the same. Got: {} and {}".format( + attn_mask.shape, attn.shape + ) + attn = attn + attn_mask + + if key_padding_mask is not None: + # Do not attend to padding positions + # key padding mask size is [N, P] + batch_size, num_src_tokens, num_tgt_tokens = attn.shape + assert key_padding_mask.dim() == 2 and list(key_padding_mask.shape) == [ + batch_size, + num_tgt_tokens, + ], "Key_padding_mask should be 2-dimension with shape [{}, {}]. Got: {}".format( + batch_size, num_tgt_tokens, key_padding_mask.shape + ) + attn = attn.masked_fill( + key_padding_mask.unsqueeze(1).to(torch.bool), + float("-inf"), + ) + + attn = self.softmax(attn) + attn = self.attn_dropout(attn) + + # weighted sum + # [N, P, P] x [N, P, C] --> [N, P, C] + out = torch.matmul(attn, value) + out = self.out_proj(out) + + return out + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + b_sz, seq_len, in_channels = input.shape + params = macs = 0.0 + + qkv, p, m = module_profile(module=self.qkv_proj, x=input) + params += p + macs += m * seq_len * b_sz + + # number of operations in QK^T + m_qk = (seq_len * in_channels * in_channels) * b_sz + macs += m_qk + + # number of operations in computing weighted sum + m_wt = (seq_len * in_channels * in_channels) * b_sz + macs += m_wt + + out_p, p, m = module_profile(module=self.out_proj, x=input) + params += p + macs += m * seq_len * b_sz + + return input, params, macs diff --git a/Adaptive Frequency Filters/affnet/layers/softmax.py b/Adaptive Frequency Filters/affnet/layers/softmax.py new file mode 100644 index 0000000..c112c2f --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/softmax.py @@ -0,0 +1,27 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Optional, Tuple + + +class Softmax(nn.Softmax): + """ + Applies the Softmax function to an input tensor along the specified dimension + + Args: + dim (int): Dimension along which softmax to be applied. Default: -1 + + Shape: + - Input: :math:`(*)` where :math:`*` is one or more dimensions + - Output: same shape as the input + """ + + def __init__(self, dim: Optional[int] = -1, *args, **kwargs): + super().__init__(dim=dim) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/stocastic_depth.py b/Adaptive Frequency Filters/affnet/layers/stocastic_depth.py new file mode 100644 index 0000000..48b6455 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/stocastic_depth.py @@ -0,0 +1,23 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import Tensor +from typing import Tuple +from torchvision.ops import StochasticDepth as StochasticDepthTorch + + +class StochasticDepth(StochasticDepthTorch): + """ + Implements the Stochastic Depth `"Deep Networks with Stochastic Depth" + `_ used for randomly dropping residual + branches of residual architectures. + """ + + def __init__(self, p: float, mode: str) -> None: + super().__init__(p=p, mode=mode) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/layers/upsample.py b/Adaptive Frequency Filters/affnet/layers/upsample.py new file mode 100644 index 0000000..ccca072 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/layers/upsample.py @@ -0,0 +1,44 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from typing import Optional, Tuple, Union + + +class UpSample(nn.Upsample): + """ + This layer upsamples a given input tensor. + + Args: + size (Optional[Union[int, Tuple[int, ...]]): Output spatial size. Default: None + scale_factor (Optional[float]): Scale each spatial dimension of the input by this factor. Default: None + mode (Optional[str]): Upsampling algorithm (``'nearest'``, ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``. Default: ``'nearest'`` + align_corners (Optional[bool]): if ``True``, the corner pixels of the input and output tensors are aligned, and thus preserving the values at + those pixels. This only has effect when :attr:`mode` is ``'linear'``, ``'bilinear'``, ``'bicubic'``, or ``'trilinear'``. + Default: ``None`` + + Shape: + - Input: :math:`(N, C, W_{in})` or :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C, W_{out})` or :math:`(N, C, H_{out}, W_{out})` or :math:`(N, C, D_{out}, H_{out}, W_{out})` + """ + + def __init__( + self, + size: Optional[Union[int, Tuple[int, ...]]] = None, + scale_factor: Optional[float] = None, + mode: Optional[str] = "nearest", + align_corners: Optional[bool] = None, + *args, + **kwargs + ) -> None: + super().__init__( + size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners + ) + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + input = self.forward(input) + return input, 0.0, 0.0 diff --git a/Adaptive Frequency Filters/affnet/matcher_det/__init__.py b/Adaptive Frequency Filters/affnet/matcher_det/__init__.py new file mode 100644 index 0000000..42e8fdf --- /dev/null +++ b/Adaptive Frequency Filters/affnet/matcher_det/__init__.py @@ -0,0 +1,79 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse +import os +import importlib + +from utils import logger +from utils.ddp_utils import is_master + +from .base_matcher import BaseMatcher + +# register BOX Matcher +MATCHER_REGISTRY = {} + + +def register_matcher(name): + def register_class(cls): + if name in MATCHER_REGISTRY: + raise ValueError("Cannot register duplicate matcher ({})".format(name)) + + if not issubclass(cls, BaseMatcher): + raise ValueError( + "Matcher ({}: {}) must extend BaseMatcher".format(name, cls.__name__) + ) + + MATCHER_REGISTRY[name] = cls + return cls + + return register_class + + +def arguments_box_matcher(parser: argparse.ArgumentParser): + group = parser.add_argument_group("Matcher", "Matcher") + group.add_argument( + "--matcher.name", + type=str, + help="Name of the matcher. Matcher matches anchors with GT box coordinates", + ) + + # add segmentation specific arguments + for k, v in MATCHER_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + + return parser + + +def build_matcher(opts, *args, **kwargs): + matcher_name = getattr(opts, "matcher.name", None) + matcher = None + if matcher_name in MATCHER_REGISTRY: + matcher = MATCHER_REGISTRY[matcher_name](opts, *args, **kwargs) + else: + supported_matchers = list(MATCHER_REGISTRY.keys()) + supp_matcher_str = "Got {} as matcher. Supported matchers are:".format( + matcher_name + ) + for i, m_name in enumerate(supported_matchers): + supp_matcher_str += "\n\t {}: {}".format(i, logger.color_text(m_name)) + + if is_master(opts): + logger.error(supp_matcher_str) + return matcher + + +# automatically import the matchers +matcher_dir = os.path.dirname(__file__) +for file in os.listdir(matcher_dir): + path = os.path.join(matcher_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + matcher_py = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("affnet.matcher_det." + matcher_py) diff --git a/Adaptive Frequency Filters/affnet/matcher_det/base_matcher.py b/Adaptive Frequency Filters/affnet/matcher_det/base_matcher.py new file mode 100644 index 0000000..4259a91 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/matcher_det/base_matcher.py @@ -0,0 +1,25 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse + + +class BaseMatcher(object): + """ + Base class for matching anchor boxes and labels for the task of object detection + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super(BaseMatcher, self).__init__() + self.opts = opts + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + """Add class-specific arguments""" + return parser + + def __call__(self, *args, **kwargs): + raise NotImplementedError diff --git a/Adaptive Frequency Filters/affnet/matcher_det/ssd_matcher.py b/Adaptive Frequency Filters/affnet/matcher_det/ssd_matcher.py new file mode 100644 index 0000000..c46f98b --- /dev/null +++ b/Adaptive Frequency Filters/affnet/matcher_det/ssd_matcher.py @@ -0,0 +1,160 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import Tensor +import numpy as np +from typing import Optional, Union, Tuple +import argparse + +from utils import logger + +from . import BaseMatcher, register_matcher +from ..misc.third_party.ssd_utils import assign_priors +from ..misc.box_utils import ( + center_form_to_corner_form, + corner_form_to_center_form, + convert_boxes_to_locations, + convert_locations_to_boxes, +) + + +@register_matcher(name="ssd") +class SSDMatcher(BaseMatcher): + """ + This class assigns labels to anchors via `SSD matching process `_ + + Args: + opts: command line arguments + bg_class_id: Background class index + + Shape: + - Input: + - gt_boxes: Ground-truth boxes in corner form (xyxy format). Shape is :math:`(N, 4)` where :math:`N` is the number of boxes + - gt_labels: Ground-truth box labels. Shape is :math:`(N)` + - anchors: Anchor boxes in center form (c_x, c_y, w, h). Shape is :math:`(M, 4)` where :math:`M` is the number of anchors + + - Output: + - matched_boxes of shape :math:`(M, 4)` + - matched_box_labels of shape :math:`(M)` + """ + + def __init__(self, opts, bg_class_id: Optional[int] = 0, *args, **kwargs) -> None: + center_variance = getattr(opts, "matcher.ssd.center_variance", None) + check_variable(center_variance, "--matcher.ssd.center-variance") + + size_variance = getattr(opts, "matcher.ssd.size_variance", None) + check_variable(val=size_variance, args_str="--matcher.ssd.size-variance") + + iou_threshold = getattr(opts, "matcher.ssd.iou_threshold", None) + check_variable(val=iou_threshold, args_str="--matcher.ssd.iou-threshold") + + super().__init__(opts=opts, *args, **kwargs) + + self.center_variance = center_variance + self.size_variance = size_variance + self.iou_threshold = iou_threshold + self.bg_class_id = bg_class_id + + def __repr__(self): + return "{}(center_variance={}, size_variance={}, iou_threshold={})".format( + self.__class__.__name__, + self.center_variance, + self.size_variance, + self.iou_threshold, + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """ + Add SSD Matcher specific arguments + """ + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--matcher.ssd.center-variance", + type=float, + default=0.1, + help="Center variance for matching", + ) + group.add_argument( + "--matcher.ssd.size-variance", + type=float, + default=0.2, + help="Size variance.", + ) + group.add_argument( + "--matcher.ssd.iou-threshold", + type=float, + default=0.45, + help="IOU Threshold.", + ) + + return parser + + def __call__( + self, + gt_boxes: Union[np.ndarray, Tensor], + gt_labels: Union[np.ndarray, Tensor], + anchors: Tensor, + ) -> Tuple[Tensor, Tensor]: + if isinstance(gt_boxes, np.ndarray): + gt_boxes = torch.from_numpy(gt_boxes) + if isinstance(gt_labels, np.ndarray): + gt_labels = torch.from_numpy(gt_labels) + + # convert box priors from center [c_x, c_y] to corner_form [x, y] + anchors_xyxy = center_form_to_corner_form(boxes=anchors) + + matched_boxes_xyxy, matched_labels = assign_priors( + gt_boxes, # gt_boxes are in corner form [x, y, w, h] + gt_labels, + anchors_xyxy, # priors are in corner form [x, y, w, h] + self.iou_threshold, + background_id=self.bg_class_id, + ) + + # convert the matched boxes to center form [c_x, c_y] + matched_boxes_cxcywh = corner_form_to_center_form(matched_boxes_xyxy) + + # Eq.(2) in paper https://arxiv.org/pdf/1512.02325.pdf + boxes_for_regression = convert_boxes_to_locations( + gt_boxes=matched_boxes_cxcywh, # center form + prior_boxes=anchors, # center form + center_variance=self.center_variance, + size_variance=self.size_variance, + ) + + return boxes_for_regression, matched_labels + + def convert_to_boxes( + self, pred_locations: torch.Tensor, anchors: torch.Tensor + ) -> Tensor: + """ + Decodes boxes from predicted locations and anchors. + """ + + # decode boxes in center form + boxes = convert_locations_to_boxes( + pred_locations=pred_locations, + anchor_boxes=anchors, + center_variance=self.center_variance, + size_variance=self.size_variance, + ) + # convert boxes from center form [c_x, c_y] to corner form [x, y] + boxes = center_form_to_corner_form(boxes) + return boxes + + +def check_variable(val, args_str: str): + if val is None: + logger.error("{} cannot be None".format(args_str)) + + if not (0.0 < val < 1.0): + logger.error( + "The value of {} should be between 0 and 1. Got: {}".format(args_str, val) + ) diff --git a/Adaptive Frequency Filters/affnet/misc/__init__.py b/Adaptive Frequency Filters/affnet/misc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/affnet/misc/averaging_utils.py b/Adaptive Frequency Filters/affnet/misc/averaging_utils.py new file mode 100644 index 0000000..2a9f488 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/misc/averaging_utils.py @@ -0,0 +1,74 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn +from typing import Optional +from copy import deepcopy +import argparse + + +class EMA(object): + """ + For a given model, this class computes the exponential moving average of weights + + Args: + model (torch.nn.Module): Model + ema_momentum (Optional[float]): Momentum value shows the contribution of weights at current iteration. Default: 0.0005 + device (Optional[str]): Device (CPU or GPU) on which model resides. Default: cpu + """ + + def __init__( + self, + model: nn.Module, + ema_momentum: Optional[float] = 0.0005, + device: Optional[str] = "cpu", + *args, + **kwargs + ) -> None: + # make a deep copy of the model for accumulating moving average of parameters and set to eval mode + self.ema_model = deepcopy(model) + self.ema_model.eval() + self.momentum = ema_momentum + self.device = device + if device: + self.ema_model.to(device=device) + self.ema_has_module = hasattr(self.ema_model, "module") + for param in self.ema_model.parameters(): + param.requires_grad = False + + def update_parameters(self, model): + # correct a mismatch in state dict keys + has_module = hasattr(model, "module") and not self.ema_has_module + with torch.no_grad(): + msd = model.state_dict() + for k, ema_v in self.ema_model.state_dict().items(): + if has_module: + # .module is added if we use DistributedDataParallel or DataParallel wrappers around model + k = "module." + k + model_v = msd[k].detach() + if self.device: + model_v = model_v.to(device=self.device) + ema_v.copy_((ema_v * (1.0 - self.momentum)) + (self.momentum * model_v)) + + +def arguments_ema(parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="EMA", description="Exponential moving average arguments" + ) + group.add_argument( + "--ema.enable", action="store_true", help="Use exponential moving average" + ) + group.add_argument( + "--ema.momentum", type=float, default=0.0001, help="EMA momentum" + ) + group.add_argument( + "--ema.copy-at-epoch", + type=int, + default=-1, + help="Update model weights with EMA model at this epoch", + ) + return parser diff --git a/Adaptive Frequency Filters/affnet/misc/box_utils.py b/Adaptive Frequency Filters/affnet/misc/box_utils.py new file mode 100644 index 0000000..d35d511 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/misc/box_utils.py @@ -0,0 +1,119 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import Tensor + +""" + This file implements different conversion functions to implement `SSD `_ object detector. + Equations are written inside each function for brevity. +""" + + +def convert_locations_to_boxes( + pred_locations: Tensor, + anchor_boxes: Tensor, + center_variance: float, + size_variance: float, +) -> Tensor: + """ + This is an inverse of convert_boxes_to_locations function (or Eq.(2) in `SSD paper `_ + Args: + pred_locations (Tensor): predicted locations from detector + anchor_boxes (Tensor): prior boxes in center form + center_variance (float): variance value for centers (c_x and c_y) + size_variance (float): variance value for size (height and width) + + Returns: + predicted boxes tensor in center form + """ + + # priors can have one dimension less. + if anchor_boxes.dim() + 1 == pred_locations.dim(): + anchor_boxes = anchor_boxes.unsqueeze(0) + + # T_w = log(g_w/d_w) / size_variance ==> g_w = exp(T_w * size_variance) * d_w + # T_h = log(g_h/d_h) / size_variance ==> g_h = exp(T_h * size_variance) * d_h + pred_size = ( + torch.exp(pred_locations[..., 2:] * size_variance) * anchor_boxes[..., 2:] + ) + # T_cx = ((g_cx - d_cx) / d_w) / center_variance ==> g_cx = ((T_cx * center_variance) * d_w) + d_cx + # T_cy = ((g_cy - d_cy) / d_w) / center_variance ==> g_cy = ((T_cy * center_variance) * d_h) + d_cy + pred_center = ( + pred_locations[..., :2] * center_variance * anchor_boxes[..., 2:] + ) + anchor_boxes[..., :2] + + return torch.cat((pred_center, pred_size), dim=-1) + + +def convert_boxes_to_locations( + gt_boxes: Tensor, prior_boxes: Tensor, center_variance: float, size_variance: float +): + """ + This function implements Eq.(2) in the `SSD paper `_ + + Args: + gt_boxes (Tensor): Ground truth boxes in center form (cx, cy, w, h) + prior_boxes (Tensor): Prior boxes in center form (cx, cy, w, h) + center_variance (float): variance value for centers (c_x and c_y) + size_variance (float): variance value for size (height and width) + + Returns: + boxes tensor for training + """ + + # T_cx = ((g_cx - d_cx) / d_w) / center_variance; Center vairance is nothing but normalization + # T_cy = ((g_cy - d_cy) / d_h) / center_variance + # T_w = log(g_w/d_w) / size_variance and T_h = log(g_h/d_h) / size_varianc + + # priors can have one dimension less + if prior_boxes.dim() + 1 == gt_boxes.dim(): + prior_boxes = prior_boxes.unsqueeze(0) + + target_centers = ( + (gt_boxes[..., :2] - prior_boxes[..., :2]) / prior_boxes[..., 2:] + ) / center_variance + target_size = torch.log(gt_boxes[..., 2:] / prior_boxes[..., 2:]) / size_variance + return torch.cat((target_centers, target_size), dim=-1) + + +def center_form_to_corner_form(boxes: Tensor) -> Tensor: + """ + This function convert boxes from center to corner form + Args: + boxes (Tensor): Boxes in center form (cx,cy,w,h) + + Returns: + Boxes tensor in corner form (x,y,w,h) + """ + + # x = c_x - (delta_w * 0.5), y = c_y - (delta_h * 0.5) + # w = c_x + (delta_w * 0.5), h = c_y + (delta_h * 0.5) + return torch.cat( + ( + boxes[..., :2] - (boxes[..., 2:] * 0.5), + boxes[..., :2] + (boxes[..., 2:] * 0.5), + ), + dim=-1, + ) + + +def corner_form_to_center_form(boxes: torch.Tensor) -> torch.Tensor: + """ + This function converts boxes from corner to center form + Args: + boxes (Tensor): boxes in corner form (x, y, w, h) + + Returns: + Boxes tensor in center form (c_x, c_y, w, h) + """ + + # c_x = ( x + w ) * 0.5, c_y = (y + h) * 0.5 + # delta_w = w - x, delta_h = h - y + return torch.cat( + ((boxes[..., :2] + boxes[..., 2:]) * 0.5, boxes[..., 2:] - boxes[..., :2]), + dim=-1, + ) diff --git a/Adaptive Frequency Filters/affnet/misc/common.py b/Adaptive Frequency Filters/affnet/misc/common.py new file mode 100644 index 0000000..34b8bf3 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/misc/common.py @@ -0,0 +1,166 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +import os +import re +from typing import Any, Dict, List, Optional, Union + +from utils import logger +from utils.ddp_utils import is_start_rank_node + + +def clean_strip( + obj: Union[str, List[str]], sep: Optional[str] = ",", strip: bool = True +) -> List[str]: + # Allowing list of strings as input as well as comma-separated strings + if isinstance(obj, list): + strings = obj + else: + strings = obj.split(sep) + + if strip: + strings = [x.strip() for x in strings] + strings = [x for x in strings if x] + return strings + + +def load_pretrained_model( + model: torch.nn.Module, wt_loc: str, opts: Dict[str, Any], *args, **kwargs +) -> torch.nn.Module: + """ + Helper function to load pre-trained weights + """ + if not os.path.isfile(wt_loc): + logger.error("Pretrained file is not found here: {}".format(wt_loc)) + + wts = torch.load(wt_loc, map_location="cpu") + + is_master_node = is_start_rank_node(opts) + + exclude_scopes = getattr(opts, "model.resume_exclude_scopes", "") + exclude_scopes: List[str] = clean_strip(exclude_scopes) + + missing_scopes = getattr(opts, "model.ignore_missing_scopes", "") + missing_scopes: List[str] = clean_strip(missing_scopes) + + rename_scopes_map: List[List[str]] = getattr(opts, "model.rename_scopes_map", []) + if rename_scopes_map: + for entry in rename_scopes_map: + if len(entry) != 2: + raise ValueError( + "Every entry in model.rename_scopes_map must contain exactly two string elements" + " for before and after. Got {}.".format(str(entry)) + ) + + # By default, adding scopes that we exclude to missing scopes + # If you excluded something, you can't expect it to be there. + missing_scopes += exclude_scopes + + # remove unwanted scopes + if exclude_scopes: + for key in wts.copy(): + if any([re.match(x, key) for x in exclude_scopes]): + del wts[key] + + if rename_scopes_map: + for before, after in rename_scopes_map: + wts = {re.sub(before, after, key): value for key, value in wts.items()} + + strict = not bool(missing_scopes) + + try: + module = model.module if hasattr(model, "module") else model + missing_keys, unexpected_keys = module.load_state_dict(wts, strict=strict) + + if unexpected_keys: + raise Exception( + "Found unexpected keys: {}." + "You can ignore these keys using `model.resume_exclude_scopes`.".format( + ",".join(unexpected_keys) + ) + ) + + missing_keys = [ + key + for key in missing_keys + if not any([re.match(x, key) for x in missing_scopes]) + ] + + if missing_keys: + raise Exception( + "Missing keys detected. Did not find the following keys in pre-trained model: {}." + " You can ignore the keys using `model.ignore_missing_scopes`.".format( + ",".join(missing_keys) + ) + ) + + if is_master_node: + logger.log("Pretrained weights are loaded from {}".format(wt_loc)) + except Exception as e: + if is_master_node: + logger.error( + "Unable to load pretrained weights from {}. Error: {}".format(wt_loc, e) + ) + + return model + + +def parameter_list( + named_parameters, + weight_decay: Optional[float] = 0.0, + no_decay_bn_filter_bias: Optional[bool] = False, + *args, + **kwargs +): + module_name = kwargs.get("module_name", "") + with_decay = [] + without_decay = [] + with_decay_param_names = [] + without_decay_param_names = [] + if isinstance(named_parameters, list): + for n_parameter in named_parameters: + for p_name, param in n_parameter(): + if ( + param.requires_grad + and len(param.shape) == 1 + and no_decay_bn_filter_bias + ): + # biases and normalization layer parameters are of len 1 + without_decay.append(param) + without_decay_param_names.append(module_name + p_name) + elif param.requires_grad: + with_decay.append(param) + with_decay_param_names.append(module_name + p_name) + else: + for p_name, param in named_parameters(): + if ( + param.requires_grad + and len(param.shape) == 1 + and no_decay_bn_filter_bias + ): + # biases and normalization layer parameters are of len 1 + without_decay.append(param) + without_decay_param_names.append(module_name + p_name) + elif param.requires_grad: + with_decay.append(param) + with_decay_param_names.append(module_name + p_name) + param_list = [ + { + "params": with_decay, + "weight_decay": weight_decay, + "param_names": with_decay_param_names, + } + ] + if len(without_decay) > 0: + param_list.append( + { + "params": without_decay, + "weight_decay": 0.0, + "param_names": without_decay_param_names, + } + ) + return param_list diff --git a/Adaptive Frequency Filters/affnet/misc/init_utils.py b/Adaptive Frequency Filters/affnet/misc/init_utils.py new file mode 100644 index 0000000..aaaf9e3 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/misc/init_utils.py @@ -0,0 +1,151 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn +from typing import Optional + +from utils import logger + +from ..layers import LinearLayer, GroupLinear, norm_layers_tuple + +supported_conv_inits = [ + "kaiming_normal", + "kaiming_uniform", + "xavier_normal", + "xavier_uniform", + "normal", + "trunc_normal", +] +supported_fc_inits = [ + "kaiming_normal", + "kaiming_uniform", + "xavier_normal", + "xavier_uniform", + "normal", + "trunc_normal", +] + + +def _init_nn_layers( + module, + init_method: Optional[str] = "kaiming_normal", + std_val: Optional[float] = None, +) -> None: + """ + Helper function to initialize neural network module + """ + init_method = init_method.lower() + if init_method == "kaiming_normal": + if module.weight is not None: + nn.init.kaiming_normal_(module.weight, mode="fan_out") + if module.bias is not None: + nn.init.zeros_(module.bias) + elif init_method == "kaiming_uniform": + if module.weight is not None: + nn.init.kaiming_uniform_(module.weight, mode="fan_out") + if module.bias is not None: + nn.init.zeros_(module.bias) + elif init_method == "xavier_normal": + if module.weight is not None: + nn.init.xavier_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif init_method == "xavier_uniform": + if module.weight is not None: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif init_method == "normal": + if module.weight is not None: + std = 1.0 / module.weight.size(1) ** 0.5 if std_val is None else std_val + nn.init.normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif init_method == "trunc_normal": + if module.weight is not None: + std = 1.0 / module.weight.size(1) ** 0.5 if std_val is None else std_val + nn.init.trunc_normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + supported_conv_message = "Supported initialization methods are:" + for i, l in enumerate(supported_conv_inits): + supported_conv_message += "\n \t {}) {}".format(i, l) + logger.error("{} \n Got: {}".format(supported_conv_message, init_method)) + + +def initialize_conv_layer( + module, + init_method: Optional[str] = "kaiming_normal", + std_val: Optional[float] = 0.01, +) -> None: + """Helper function to initialize convolution layers""" + _init_nn_layers(module=module, init_method=init_method, std_val=std_val) + + +def initialize_fc_layer( + module, init_method: Optional[str] = "normal", std_val: Optional[float] = 0.01 +) -> None: + """Helper function to initialize fully-connected layers""" + if hasattr(module, "layer"): + _init_nn_layers(module=module.layer, init_method=init_method, std_val=std_val) + else: + _init_nn_layers(module=module, init_method=init_method, std_val=std_val) + + +def initialize_norm_layers(module) -> None: + """Helper function to initialize normalization layers""" + + def _init_fn(module): + if hasattr(module, "weight") and module.weight is not None: + nn.init.ones_(module.weight) + if hasattr(module, "bias") and module.bias is not None: + nn.init.zeros_(module.bias) + + _init_fn(module.layer) if hasattr(module, "layer") else _init_fn(module=module) + + +def initialize_weights(opts, modules) -> None: + """Helper function to initialize differnet layers in a model""" + # weight initialization + conv_init_type = getattr(opts, "model.layer.conv_init", "kaiming_normal") + linear_init_type = getattr(opts, "model.layer.linear_init", "normal") + + conv_std = getattr(opts, "model.layer.conv_init_std_dev", None) + linear_std = getattr(opts, "model.layer.linear_init_std_dev", 0.01) + group_linear_std = getattr(opts, "model.layer.group_linear_init_std_dev", 0.01) + + if isinstance(modules, nn.Sequential): + for m in modules: + if isinstance(m, (nn.Conv2d, nn.Conv3d)): + initialize_conv_layer( + module=m, init_method=conv_init_type, std_val=conv_std + ) + elif isinstance(m, norm_layers_tuple): + initialize_norm_layers(module=m) + elif isinstance(m, (nn.Linear, LinearLayer)): + initialize_fc_layer( + module=m, init_method=linear_init_type, std_val=linear_std + ) + elif isinstance(m, GroupLinear): + initialize_fc_layer( + module=m, init_method=linear_init_type, std_val=group_linear_std + ) + else: + if isinstance(modules, (nn.Conv2d, nn.Conv3d)): + initialize_conv_layer( + module=modules, init_method=conv_init_type, std_val=conv_std + ) + elif isinstance(modules, norm_layers_tuple): + initialize_norm_layers(module=modules) + elif isinstance(modules, (nn.Linear, LinearLayer)): + initialize_fc_layer( + module=modules, init_method=linear_init_type, std_val=linear_std + ) + elif isinstance(modules, GroupLinear): + initialize_fc_layer( + module=modules, init_method=linear_init_type, std_val=group_linear_std + ) diff --git a/Adaptive Frequency Filters/affnet/misc/profiler.py b/Adaptive Frequency Filters/affnet/misc/profiler.py new file mode 100644 index 0000000..f2b1885 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/misc/profiler.py @@ -0,0 +1,33 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from typing import Tuple + + +def module_profile(module, x: Tensor, *args, **kwargs) -> Tuple[Tensor, float, float]: + """ + Helper function to profile a module. + + .. note:: + Module profiling is for reference only and may contain errors as it solely relies on user implementation to + compute theoretical FLOPs + """ + + if isinstance(module, nn.Sequential): + n_macs = n_params = 0.0 + for l in module: + try: + x, l_p, l_macs = l.profile_module(x) + n_macs += l_macs + n_params += l_p + except Exception as e: + print(e, l) + pass + else: + x, n_params, n_macs = module.profile_module(x) + return x, n_params, n_macs diff --git a/Adaptive Frequency Filters/affnet/misc/third_party/__init__.py b/Adaptive Frequency Filters/affnet/misc/third_party/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/affnet/misc/third_party/ssd_utils.py b/Adaptive Frequency Filters/affnet/misc/third_party/ssd_utils.py new file mode 100644 index 0000000..c1102ed --- /dev/null +++ b/Adaptive Frequency Filters/affnet/misc/third_party/ssd_utils.py @@ -0,0 +1,126 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import Tensor +import math +from typing import Optional, Tuple + +""" +This source code in this file is adapted from following repos, both of which are released under MIT license. + Repository Link: + https://github.com/sacmehta/EdgeNets + https://github.com/qfgaohao/pytorch-ssd + File Link: + https://github.com/sacmehta/EdgeNets/blob/master/model/detection/match_priors.py +""" + + +def assign_priors( + gt_boxes: Tensor, + gt_labels: Tensor, + corner_form_priors: Tensor, + iou_threshold: float, + background_id: Optional[int] = 0, + *args, + **kwargs +) -> Tuple[Tensor, Tensor]: + """ + Assign ground truth boxes and targets to priors (or anchors) + + Args: + gt_boxes (Tensor): Ground-truth boxes tensor of shape (num_targets, 4) + gt_labels (Tensor): Ground-truth labels of shape (num_targets) + corner_form_priors (Tensor): Priors in corner form and has shape (num_priors, 4) + iou_threshold (float): Overlap between priors and gt_boxes. + background_id (int): Background class index. Default: 0 + + Returns: + boxes (Tensor): Boxes mapped to priors and has shape (num_priors, 4) + labels (Tensor): Labels for mapped boxes and has shape (num_priors) + """ + + if gt_labels.nelement() == 0: + # Images may not have any labels + dev = corner_form_priors.device + gt_boxes = torch.zeros((1, 4), dtype=torch.float32, device=dev) + gt_labels = torch.zeros(1, dtype=torch.int64, device=dev) + + ious = box_iou(gt_boxes.unsqueeze(0), corner_form_priors.unsqueeze(1)) + + # size: num_priors + best_target_per_prior, best_target_per_prior_index = ious.max(1) + # size: num_targets + best_prior_per_target, best_prior_per_target_index = ious.max(0) + + for target_index, prior_index in enumerate(best_prior_per_target_index): + best_target_per_prior_index[prior_index] = target_index + # 2.0 is used to make sure every target has a prior assigned + best_target_per_prior.index_fill_(0, best_prior_per_target_index, 2) + # size: num_priors + labels = gt_labels[best_target_per_prior_index] + labels[best_target_per_prior < iou_threshold] = background_id + boxes = gt_boxes[best_target_per_prior_index] + return boxes, labels + + +def box_iou( + boxes0: Tensor, boxes1: Tensor, eps: Optional[float] = 1e-5, *args, **kwargs +) -> Tensor: + """ + Computes intersection-over-union between two boxes + Args: + boxes0 (Tensor): Boxes 0 of shape (N, 4) + boxes1 (Tensor): Boxes 1 of shape (N or 1, 4) + eps (Optional[float]): A small value is added to denominator for numerical stability + + Returns: + iou (Tensor): IoU values between boxes0 and boxes1 and has shape (N) + """ + + def area_of(left_top, right_bottom) -> torch.Tensor: + """ + Given two corners of the rectangle, compute the area + Args: + left_top (N, 2): left top corner. + right_bottom (N, 2): right bottom corner. + Returns: + area (N): return the area. + """ + hw = torch.clamp(right_bottom - left_top, min=0.0) + return hw[..., 0] * hw[..., 1] + + overlap_left_top = torch.max(boxes0[..., :2], boxes1[..., :2]) + overlap_right_bottom = torch.min(boxes0[..., 2:], boxes1[..., 2:]) + + overlap_area = area_of(overlap_left_top, overlap_right_bottom) + area0 = area_of(boxes0[..., :2], boxes0[..., 2:]) + area1 = area_of(boxes1[..., :2], boxes1[..., 2:]) + return overlap_area / (area0 + area1 - overlap_area + eps) + + +def hard_negative_mining( + loss: Tensor, labels: Tensor, neg_pos_ratio: int, *args, **kwargs +) -> Tensor: + """ + This function is used to suppress the presence of a large number of negative predictions. For any example/image, + it keeps all the positive predictions and cut the number of negative predictions to make sure the ratio + between the negative examples and positive examples is no more than the given ratio for an image. + Args: + loss (Tensor): the loss for each example and has shape (N, num_priors). + labels (Tensor): the labels and has shape (N, num_priors). + neg_pos_ratio (int): the ratio between the negative examples and positive examples. Usually, it is set as 3. + + """ + pos_mask = labels > 0 + num_pos = pos_mask.long().sum(dim=1, keepdim=True) + num_neg = num_pos * neg_pos_ratio + + loss[pos_mask] = -math.inf + _, indexes = loss.sort(dim=1, descending=True) + _, orders = indexes.sort(dim=1) + neg_mask = orders < num_neg + return pos_mask | neg_mask diff --git a/Adaptive Frequency Filters/affnet/models/__init__.py b/Adaptive Frequency Filters/affnet/models/__init__.py new file mode 100644 index 0000000..61861af --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/__init__.py @@ -0,0 +1,109 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import importlib +import argparse +from utils import logger +import os + + +SUPPORTED_TASKS = [] +TASK_REGISTRY = {} +TASK_ARG_REGISTRY = {} + + +def register_tasks(name): + def register_task_class(cls): + if name in TASK_REGISTRY: + raise ValueError("Cannot register duplicate task ({})".format(name)) + + TASK_REGISTRY[name] = cls + SUPPORTED_TASKS.append(name) + return cls + + return register_task_class + + +def register_task_arguments(name): + def register_task_arg_fn(fn): + if name in TASK_ARG_REGISTRY: + raise ValueError( + "Cannot register duplicate task arguments ({})".format(name) + ) + + TASK_ARG_REGISTRY[name] = fn + return fn + + return register_task_arg_fn + + +def common_model_argumnets(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + # load model scopes + parser.add_argument( + "--model.resume-exclude-scopes", + type=str, + default="", + help="Comma-separated list of parameter scopes (regex strings) to exclude when loading a pre-trained model", + ) + parser.add_argument( + "--model.ignore-missing-scopes", + type=str, + default="", + help="Comma-separated list of parameter scopes (regex strings) to ignore if they are missing from the pre-training model", + ) + parser.add_argument( + "--model.rename-scopes-map", + type=list, + default=None, + help="A mapping from checkpoint variable names to match the existing model names." + " The mapping is represented as a List[List[str]], e.g. [['before', 'after'], ['this', 'that']]." + " Note: only loading from Yaml file is supported for this argument.", + ) + return parser + + +def arguments_model(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + # common arguments + parser = common_model_argumnets(parser=parser) + + for k, v in TASK_ARG_REGISTRY.items(): + parser = v(parser) + return parser + + +def get_model(opts, *args, **kwargs): + dataset_category = getattr(opts, "dataset.category", None) + if not dataset_category: + task_str = "Please specify dataset.category. Supported categories are:" + for i, task_name in enumerate(SUPPORTED_TASKS): + task_str += "\n\t {}: {}".format(i, task_name) + logger.error(task_str) + + dataset_category = dataset_category.lower() + + if dataset_category in TASK_REGISTRY: + return TASK_REGISTRY[dataset_category](opts, *args, **kwargs) + else: + task_str = ( + "Got {} as a task. Unfortunately, we do not support it yet." + "\nSupported tasks are:".format(dataset_category) + ) + for i, task_name in enumerate(SUPPORTED_TASKS): + task_str += "\n\t {}: {}".format(i, task_name) + logger.error(task_str) + + +# automatically import the tasks +tasks_dir = os.path.dirname(__file__) +for file in os.listdir(tasks_dir): + path = os.path.join(tasks_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + task_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("affnet.models." + task_name) diff --git a/Adaptive Frequency Filters/affnet/models/classification/__init__.py b/Adaptive Frequency Filters/affnet/models/classification/__init__.py new file mode 100644 index 0000000..1f84430 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/classification/__init__.py @@ -0,0 +1,146 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + + +import os +import importlib +import argparse + +from utils.download_utils import get_local_path +from utils import logger +from utils.common_utils import check_frozen_norm_layer +from utils.ddp_utils import is_master, is_start_rank_node + +from .. import register_tasks, register_task_arguments +from .base_cls import BaseEncoder +from ...misc.common import load_pretrained_model + +CLS_MODEL_REGISTRY = {} + + +def register_cls_models(name): + def register_model_class(cls): + if name in CLS_MODEL_REGISTRY: + raise ValueError("Cannot register duplicate model ({})".format(name)) + + if not issubclass(cls, BaseEncoder): + raise ValueError( + "Model ({}: {}) must extend BaseEncoder".format(name, cls.__name__) + ) + + CLS_MODEL_REGISTRY[name] = cls + return cls + + return register_model_class + + +@register_tasks(name="classification") +def build_classification_model(opts, *args, **kwargs): + model_name = getattr(opts, "model.classification.name", None) + model = None + is_master_node = is_master(opts) + if model_name in CLS_MODEL_REGISTRY: + cls_act_fn = getattr(opts, "model.classification.activation.name", None) + if cls_act_fn is not None: + # Override the general activation arguments + gen_act_fn = getattr(opts, "model.activation.name", "relu") + gen_act_inplace = getattr(opts, "model.activation.inplace", False) + gen_act_neg_slope = getattr(opts, "model.activation.neg_slope", 0.1) + + setattr(opts, "model.activation.name", cls_act_fn) + setattr( + opts, + "model.activation.inplace", + getattr(opts, "model.classification.activation.inplace", False), + ) + setattr( + opts, + "model.activation.neg_slope", + getattr(opts, "model.classification.activation.neg_slope", 0.1), + ) + + model = CLS_MODEL_REGISTRY[model_name](opts, *args, **kwargs) + + # Reset activation args + setattr(opts, "model.activation.name", gen_act_fn) + setattr(opts, "model.activation.inplace", gen_act_inplace) + setattr(opts, "model.activation.neg_slope", gen_act_neg_slope) + else: + model = CLS_MODEL_REGISTRY[model_name](opts, *args, **kwargs) + else: + supported_models = list(CLS_MODEL_REGISTRY.keys()) + supp_model_str = "Supported models are:" + for i, m_name in enumerate(supported_models): + supp_model_str += "\n\t {}: {}".format(i, logger.color_text(m_name)) + + if is_master_node: + logger.error(supp_model_str + "Got: {}".format(model_name)) + + finetune_task = getattr( + opts, "model.classification.finetune_pretrained_model", False + ) + pretrained = getattr(opts, "model.classification.pretrained", None) + if finetune_task: + n_pretrained_classes = getattr( + opts, "model.classification.n_pretrained_classes", None + ) + n_classes = getattr(opts, "model.classification.n_classes", None) + assert n_pretrained_classes is not None + assert n_classes is not None + + # The model structure is the same as pre-trained model now + model.update_classifier(opts, n_classes=n_pretrained_classes) + + # load the weights + if pretrained is not None: + pretrained = get_local_path(opts, path=pretrained) + model = load_pretrained_model(model=model, wt_loc=pretrained, opts=opts) + + # Now, re-initialize the classification layer + model.update_classifier(opts, n_classes=n_classes) + + elif pretrained is not None: + pretrained = get_local_path(opts, path=pretrained) + model = load_pretrained_model(model=model, wt_loc=pretrained, opts=opts) + + freeze_norm_layers = getattr(opts, "model.classification.freeze_batch_norm", False) + if freeze_norm_layers: + model.freeze_norm_layers() + frozen_state, count_norm = check_frozen_norm_layer(model) + if count_norm > 0 and frozen_state and is_master_node: + logger.error( + "Something is wrong while freezing normalization layers. Please check" + ) + + if is_master_node: + logger.log("Normalization layers are frozen") + + return model + + +@register_task_arguments(name="classification") +def arguments_classification(parser: argparse.ArgumentParser): + # add arguments (if any) specified in BaseEncoder class + parser = BaseEncoder.add_arguments(parser=parser) + + # add classification specific arguments + for k, v in CLS_MODEL_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + + return parser + + +# automatically import the models +models_dir = os.path.dirname(__file__) +for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("affnet.models.classification." + model_name) diff --git a/Adaptive Frequency Filters/affnet/models/classification/affnet.py b/Adaptive Frequency Filters/affnet/models/classification/affnet.py new file mode 100644 index 0000000..4ade19a --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/classification/affnet.py @@ -0,0 +1,304 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn +import argparse +from typing import Dict, Tuple, Optional + +from utils import logger + +from . import register_cls_models +from .base_cls import BaseEncoder +from .config.affnet import get_configuration +from ...layers import ConvLayer, LinearLayer, GlobalPool, Dropout, SeparableConv +from ...modules import InvertedResidual +from ...modules.aff_block import AFFBlock + + +@register_cls_models("affnet") +class AffNet(BaseEncoder): + """ + This class implements the `MobileViT architecture `_ + """ + + def __init__(self, opts, *args, **kwargs) -> None: + num_classes = getattr(opts, "model.classification.n_classes", 1000) + classifier_dropout = getattr( + opts, "model.classification.classifier_dropout", 0.0 + ) + + pool_type = getattr(opts, "model.layer.global_pool", "mean") + image_channels = 3 + out_channels = 16 + + mobilevit_config = get_configuration(opts=opts) + + super().__init__(opts, *args, **kwargs) + + # store model configuration in a dictionary + self.model_conf_dict = dict() + self.conv_1 = ConvLayer( + opts=opts, + in_channels=image_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + use_norm=True, + use_act=True, + ) + + self.model_conf_dict["conv1"] = {"in": image_channels, "out": out_channels} + + in_channels = out_channels + self.layer_1, out_channels = self._make_layer( + opts=opts, input_channel=in_channels, cfg=mobilevit_config["layer1"] + ) + self.model_conf_dict["layer1"] = {"in": in_channels, "out": out_channels} + + in_channels = out_channels + self.layer_2, out_channels = self._make_layer( + opts=opts, input_channel=in_channels, cfg=mobilevit_config["layer2"] + ) + self.model_conf_dict["layer2"] = {"in": in_channels, "out": out_channels} + + in_channels = out_channels + self.layer_3, out_channels = self._make_layer( + opts=opts, input_channel=in_channels, cfg=mobilevit_config["layer3"] + ) + self.model_conf_dict["layer3"] = {"in": in_channels, "out": out_channels} + + in_channels = out_channels + self.layer_4, out_channels = self._make_layer( + opts=opts, + input_channel=in_channels, + cfg=mobilevit_config["layer4"], + dilate=self.dilate_l4, + ) + self.model_conf_dict["layer4"] = {"in": in_channels, "out": out_channels} + + in_channels = out_channels + self.layer_5, out_channels = self._make_layer( + opts=opts, + input_channel=in_channels, + cfg=mobilevit_config["layer5"], + dilate=self.dilate_l5, + ) + self.model_conf_dict["layer5"] = {"in": in_channels, "out": out_channels} + + in_channels = out_channels + exp_channels = min(mobilevit_config["last_layer_exp_factor"] * in_channels, 960) + self.conv_1x1_exp = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=exp_channels, + kernel_size=1, + stride=1, + use_act=True, + use_norm=True, + ) + + self.model_conf_dict["exp_before_cls"] = { + "in": in_channels, + "out": exp_channels, + } + + self.classifier = nn.Sequential() + self.classifier.add_module( + name="global_pool", module=GlobalPool(pool_type=pool_type, keep_dim=False) + ) + if 0.0 < classifier_dropout < 1.0: + self.classifier.add_module( + name="dropout", module=Dropout(p=classifier_dropout, inplace=True) + ) + self.classifier.add_module( + name="fc", + module=LinearLayer( + in_features=exp_channels, out_features=num_classes, bias=True + ), + ) + + # check model + self.check_model() + + # weight initialization + self.reset_parameters(opts=opts) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--model.classification.affnet.mode", + type=str, + default="small", + choices=["xx_small", "x_small", "small"], + help="MobileViT mode. Defaults to small", + ) + group.add_argument( + "--model.classification.affnet.attn-dropout", + type=float, + default=0.0, + help="Dropout in attention layer. Defaults to 0.0", + ) + group.add_argument( + "--model.classification.affnet.ffn-dropout", + type=float, + default=0.0, + help="Dropout between FFN layers. Defaults to 0.0", + ) + group.add_argument( + "--model.classification.affnet.dropout", + type=float, + default=0.0, + help="Dropout in Transformer layer. Defaults to 0.0", + ) + group.add_argument( + "--model.classification.affnet.attn-norm-layer", + type=str, + default="layer_norm", + help="Normalization layer in transformer. Defaults to LayerNorm", + ) + group.add_argument( + "--model.classification.affnet.no-fuse-local-global-features", + action="store_true", + help="Do not combine local and global features in MobileViT block", + ) + group.add_argument( + "--model.classification.affnet.conv-kernel-size", + type=int, + default=3, + help="Kernel size of Conv layers in MobileViT block", + ) + + group.add_argument( + "--model.classification.affnet.head-dim", + type=int, + default=None, + help="Head dimension in transformer", + ) + group.add_argument( + "--model.classification.affnet.number-heads", + type=int, + default=None, + help="Number of heads in transformer", + ) + return parser + + def _make_layer( + self, + opts, + input_channel, + cfg: Dict, + dilate: Optional[bool] = False, + *args, + **kwargs + ) -> Tuple[nn.Sequential, int]: + block_type = cfg.get("block_type", "mobilevit") + if block_type.lower() == "mobilevit": + return self._make_affnet_layer( + opts=opts, input_channel=input_channel, cfg=cfg, dilate=dilate + ) + else: + return self._make_mobilenet_layer( + opts=opts, input_channel=input_channel, cfg=cfg + ) + + @staticmethod + def _make_mobilenet_layer( + opts, input_channel: int, cfg: Dict, *args, **kwargs + ) -> Tuple[nn.Sequential, int]: + output_channels = cfg.get("out_channels") + num_blocks = cfg.get("num_blocks", 2) + expand_ratio = cfg.get("expand_ratio", 4) + block = [] + + for i in range(num_blocks): + stride = cfg.get("stride", 1) if i == 0 else 1 + + layer = InvertedResidual( + opts=opts, + in_channels=input_channel, + out_channels=output_channels, + stride=stride, + expand_ratio=expand_ratio, + ) + block.append(layer) + input_channel = output_channels + return nn.Sequential(*block), input_channel + + def _make_affnet_layer( + self, + opts, + input_channel, + cfg: Dict, + dilate: Optional[bool] = False, + *args, + **kwargs + ) -> Tuple[nn.Sequential, int]: + prev_dilation = self.dilation + block = [] + stride = cfg.get("stride", 1) + no_fuse = cfg.get("no_fuse", False) + + if stride == 2: + if dilate: + self.dilation *= 2 + stride = 1 + + layer = InvertedResidual( + opts=opts, + in_channels=input_channel, + out_channels=cfg.get("out_channels"), + stride=stride, + expand_ratio=cfg.get("mv_expand_ratio", 4), + dilation=prev_dilation, + ) + + block.append(layer) + input_channel = cfg.get("out_channels") + + head_dim = cfg.get("head_dim", 32) + transformer_dim = cfg["transformer_channels"] + ffn_dim = cfg.get("ffn_dim") + if head_dim is None: + num_heads = cfg.get("num_heads", 4) + if num_heads is None: + num_heads = 4 + head_dim = transformer_dim // num_heads + + if transformer_dim % head_dim != 0: + logger.error( + "Transformer input dimension should be divisible by head dimension. " + "Got {} and {}.".format(transformer_dim, head_dim) + ) + + block.append( + AFFBlock( + opts=opts, + in_channels=input_channel, + transformer_dim=transformer_dim, + ffn_dim=ffn_dim, + n_transformer_blocks=cfg.get("transformer_blocks", 1), + patch_h=cfg.get("patch_h", 2), + patch_w=cfg.get("patch_w", 2), + dropout=getattr(opts, "model.classification.affnet.dropout", 0.1), + ffn_dropout=getattr(opts, "model.classification.affnet.ffn_dropout", 0.0), + attn_dropout=getattr( + opts, "model.classification.affnet.attn_dropout", 0.1 + ), + head_dim=head_dim, + no_fusion=no_fuse, + conv_ksize=getattr( + opts, "model.classification.affnet.conv_kernel_size", 3 + ), + attn_norm_layer=getattr( + opts, "model.classification.affnet.attn_norm_layer", "layer_norm_2d" + ), + ) + ) + + return nn.Sequential(*block), input_channel diff --git a/Adaptive Frequency Filters/affnet/models/classification/base_cls.py b/Adaptive Frequency Filters/affnet/models/classification/base_cls.py new file mode 100644 index 0000000..16e5e42 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/classification/base_cls.py @@ -0,0 +1,459 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- +import torch +from torch import nn, Tensor +from torch.utils.checkpoint import checkpoint as gradient_checkpoint_fn +from typing import Optional, Dict, Tuple, Union, Any +import argparse + +from utils import logger + +from ... import parameter_list +from ...layers import norm_layers_tuple, LinearLayer +from ...misc.profiler import module_profile +from ...misc.init_utils import initialize_weights, initialize_fc_layer + +from ...neural_augmentor import build_neural_augmentor + + +class BaseEncoder(nn.Module): + """ + Base class for different classification models + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__() + self.conv_1 = None + self.layer_1 = None + self.layer_2 = None + self.layer_3 = None + self.layer_4 = None + self.layer_5 = None + self.conv_1x1_exp = None + self.classifier = None + self.round_nearest = 8 + + # Segmentation architectures like Deeplab and PSPNet modifies the strides of the backbone + # We allow that using output_stride and replace_stride_with_dilation arguments + self.dilation = 1 + output_stride = kwargs.get("output_stride", None) + self.dilate_l4 = False + self.dilate_l5 = False + if output_stride == 8: + self.dilate_l4 = True + self.dilate_l5 = True + elif output_stride == 16: + self.dilate_l5 = True + + self.model_conf_dict = dict() + self.neural_augmentor = build_neural_augmentor(opts=opts, *args, **kwargs) + self.gradient_checkpointing = getattr( + opts, "model.classification.gradient_checkpointing", False + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + """Add model-specific arguments""" + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--model.classification.classifier-dropout", + type=float, + default=0.0, + help="Dropout rate in classifier", + ) + + group.add_argument( + "--model.classification.name", type=str, default=None, help="Model name" + ) + group.add_argument( + "--model.classification.n-classes", + type=int, + default=1000, + help="Number of classes in the dataset", + ) + group.add_argument( + "--model.classification.pretrained", + type=str, + default=None, + help="Path of the pretrained backbone", + ) + group.add_argument( + "--model.classification.freeze-batch-norm", + action="store_true", + help="Freeze batch norm layers", + ) + group.add_argument( + "--model.classification.activation.name", + default=None, + type=str, + help="Non-linear function name (e.g., relu)", + ) + group.add_argument( + "--model.classification.activation.inplace", + action="store_true", + help="Inplace non-linear functions", + ) + group.add_argument( + "--model.classification.activation.neg-slope", + default=0.1, + type=float, + help="Negative slope in leaky relu", + ) + + group.add_argument( + "--model.classification.finetune-pretrained-model", + action="store_true", + help="Finetune a pretrained model", + ) + group.add_argument( + "--model.classification.n-pretrained-classes", + type=int, + default=None, + help="Number of pre-trained classes", + ) + + group.add_argument( + "--model.classification.gradient-checkpointing", + action="store_true", + help="Checkpoint output of each spatial level in the classification backbone. Note that" + "we only take care of checkpointing in {}. If custom forward functions are used, please" + "implement checkpointing accordingly", + ) + + return parser + + def check_model(self): + assert ( + self.model_conf_dict + ), "Model configuration dictionary should not be empty" + assert self.conv_1 is not None, "Please implement self.conv_1" + assert self.layer_1 is not None, "Please implement self.layer_1" + assert self.layer_2 is not None, "Please implement self.layer_2" + assert self.layer_3 is not None, "Please implement self.layer_3" + assert self.layer_4 is not None, "Please implement self.layer_4" + assert self.layer_5 is not None, "Please implement self.layer_5" + assert self.conv_1x1_exp is not None, "Please implement self.conv_1x1_exp" + assert self.classifier is not None, "Please implement self.classifier" + + def reset_parameters(self, opts): + """Initialize model weights""" + initialize_weights(opts=opts, modules=self.modules()) + + def update_classifier(self, opts, n_classes: int) -> None: + """ + This function updates the classification layer in a model. Useful for finetuning purposes. + """ + linear_init_type = getattr(opts, "model.layer.linear_init", "normal") + if isinstance(self.classifier, nn.Sequential): + in_features = self.classifier[-1].in_features + layer = LinearLayer( + in_features=in_features, out_features=n_classes, bias=True + ) + initialize_fc_layer(layer, init_method=linear_init_type) + self.classifier[-1] = layer + else: + in_features = self.classifier.in_features + layer = LinearLayer( + in_features=in_features, out_features=n_classes, bias=True + ) + initialize_fc_layer(layer, init_method=linear_init_type) + + # re-init head + head_init_scale = 0.001 + layer.weight.data.mul_(head_init_scale) + layer.bias.data.mul_(head_init_scale) + + self.classifier = layer + + def _forward_layer(self, layer: nn.Module, x: Tensor) -> Tensor: + # Larger models with large input image size may not be able to fit into memory. + # We can use gradient checkpointing to enable training with large models and large inputs + return ( + gradient_checkpoint_fn(layer, x) + if self.gradient_checkpointing + else layer(x) + ) + + def extract_end_points_all( + self, + x: Tensor, + use_l5: Optional[bool] = True, + use_l5_exp: Optional[bool] = False, + *args, + **kwargs + ) -> Dict[str, Tensor]: + out_dict = {} # Use dictionary over NamedTuple so that JIT is happy + + if self.training and self.neural_augmentor is not None: + x = self.neural_augmentor(x) + out_dict["augmented_tensor"] = x + + x = self._forward_layer(self.conv_1, x) # 112 x112 + x = self._forward_layer(self.layer_1, x) # 112 x112 + out_dict["out_l1"] = x + + x = self._forward_layer(self.layer_2, x) # 56 x 56 + out_dict["out_l2"] = x + + x = self._forward_layer(self.layer_3, x) # 28 x 28 + out_dict["out_l3"] = x + + x = self._forward_layer(self.layer_4, x) # 14 x 14 + out_dict["out_l4"] = x + + if use_l5: + x = self._forward_layer(self.layer_5, x) # 7 x 7 + out_dict["out_l5"] = x + + if use_l5_exp: + x = self._forward_layer(self.conv_1x1_exp, x) + out_dict["out_l5_exp"] = x + return out_dict + + def extract_end_points_l4(self, x: Tensor, *args, **kwargs) -> Dict[str, Tensor]: + return self.extract_end_points_all(x, use_l5=False) + + def _extract_features(self, x: Tensor, *args, **kwargs) -> Tensor: + x = self._forward_layer(self.conv_1, x) + x = self._forward_layer(self.layer_1, x) + x = self._forward_layer(self.layer_2, x) + x = self._forward_layer(self.layer_3, x) + + x = self._forward_layer(self.layer_4, x) + x = self._forward_layer(self.layer_5, x) + x = self._forward_layer(self.conv_1x1_exp, x) + return x + + def _forward_classifier(self, x: Tensor, *args, **kwargs) -> Tensor: + # We add another classifier function so that the classifiers + # that do not adhere to the structure of BaseEncoder can still + # use neural augmentor + x = self._extract_features(x) + x = self.classifier(x) + return x + + def forward(self, x: Any, *args, **kwargs) -> Any: + if self.neural_augmentor is not None: + if self.training: + x_aug = self.neural_augmentor(x) + prediction = self._forward_classifier(x_aug) # .detach() + out_dict = {"augmented_tensor": x_aug, "logits": prediction} + else: + out_dict = { + "augmented_tensor": None, + "logits": self._forward_classifier(x), + } + return out_dict + else: + x = self._forward_classifier(x, *args, **kwargs) + return x + + def freeze_norm_layers(self) -> None: + """Freeze normalization layers""" + for m in self.modules(): + if isinstance(m, norm_layers_tuple): + m.eval() + m.weight.requires_grad = False + m.bias.requires_grad = False + m.training = False + + def get_trainable_parameters( + self, + weight_decay: Optional[float] = 0.0, + no_decay_bn_filter_bias: Optional[bool] = False, + *args, + **kwargs + ): + """Get trainable parameters""" + param_list = parameter_list( + named_parameters=self.named_parameters, + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + *args, + **kwargs + ) + return param_list, [1.0] * len(param_list) + + @staticmethod + def _profile_layers( + layers, input, overall_params, overall_macs, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + if not isinstance(layers, list): + layers = [layers] + + for layer in layers: + if layer is None: + continue + input, layer_param, layer_macs = module_profile(module=layer, x=input) + + overall_params += layer_param + overall_macs += layer_macs + + if isinstance(layer, nn.Sequential): + module_name = "\n+".join([l.__class__.__name__ for l in layer]) + else: + module_name = layer.__class__.__name__ + print( + "{:<15} \t {:<5}: {:>8.3f} M \t {:<5}: {:>8.3f} M".format( + module_name, + "Params", + round(layer_param / 1e6, 3), + "MACs", + round(layer_macs / 1e6, 3), + ) + ) + logger.singe_dash_line() + return input, overall_params, overall_macs + + def dummy_input_and_label(self, batch_size: int) -> Dict: + """Create dummy input and labels for CI/CD purposes. Child classes must override it + if functionality is different. + """ + img_channels = 3 + height = 224 + width = 224 + n_labels = 10 + img_tensor = torch.randn( + batch_size, img_channels, height, width, dtype=torch.float + ) + label_tensor = torch.randint(low=0, high=n_labels, size=(batch_size,)).long() + return {"samples": img_tensor, "targets": label_tensor} + + def profile_model( + self, input: Tensor, is_classification: Optional[bool] = True, *args, **kwargs + ) -> Tuple[Union[Tensor, Dict[str, Tensor]], float, float]: + """ + Helper function to profile a model. + + .. note:: + Model profiling is for reference only and may contain errors as it solely relies on user implementation to + compute theoretical FLOPs + """ + overall_params, overall_macs = 0.0, 0.0 + + input_fvcore = input.clone() + + if is_classification: + logger.log("Model statistics for an input of size {}".format(input.size())) + logger.double_dash_line(dashes=65) + print("{:>35} Summary".format(self.__class__.__name__)) + logger.double_dash_line(dashes=65) + + out_dict = {} + input, overall_params, overall_macs = self._profile_layers( + [self.conv_1, self.layer_1], + input=input, + overall_params=overall_params, + overall_macs=overall_macs, + ) + out_dict["out_l1"] = input + + input, overall_params, overall_macs = self._profile_layers( + self.layer_2, + input=input, + overall_params=overall_params, + overall_macs=overall_macs, + ) + out_dict["out_l2"] = input + + input, overall_params, overall_macs = self._profile_layers( + self.layer_3, + input=input, + overall_params=overall_params, + overall_macs=overall_macs, + ) + out_dict["out_l3"] = input + + input, overall_params, overall_macs = self._profile_layers( + self.layer_4, + input=input, + overall_params=overall_params, + overall_macs=overall_macs, + ) + out_dict["out_l4"] = input + + input, overall_params, overall_macs = self._profile_layers( + self.layer_5, + input=input, + overall_params=overall_params, + overall_macs=overall_macs, + ) + out_dict["out_l5"] = input + + if self.conv_1x1_exp is not None: + input, overall_params, overall_macs = self._profile_layers( + self.conv_1x1_exp, + input=input, + overall_params=overall_params, + overall_macs=overall_macs, + ) + out_dict["out_l5_exp"] = input + + if is_classification: + classifier_params, classifier_macs = 0.0, 0.0 + if self.classifier is not None: + input, classifier_params, classifier_macs = module_profile( + module=self.classifier, x=input + ) + print( + "{:<15} \t {:<5}: {:>8.3f} M \t {:<5}: {:>8.3f} M".format( + "Classifier", + "Params", + round(classifier_params / 1e6, 3), + "MACs", + round(classifier_macs / 1e6, 3), + ) + ) + overall_params += classifier_params + overall_macs += classifier_macs + + logger.double_dash_line(dashes=65) + print( + "{:<20} = {:>8.3f} M".format("Overall parameters", overall_params / 1e6) + ) + overall_params_py = sum([p.numel() for p in self.parameters()]) + print( + "{:<20} = {:>8.3f} M".format( + "Overall parameters (sanity check)", overall_params_py / 1e6 + ) + ) + + # Counting Addition and Multiplication as 1 operation + print( + "{:<20} = {:>8.3f} M".format( + "Overall MACs (theoretical)", overall_macs / 1e6 + ) + ) + + # compute flops using FVCore + try: + # compute flops using FVCore also + from fvcore.nn import FlopCountAnalysis + + flop_analyzer = FlopCountAnalysis(self.eval(), input_fvcore) + flop_analyzer.unsupported_ops_warnings(False) + flop_analyzer.uncalled_modules_warnings(False) + flops_fvcore = flop_analyzer.total() + + print( + "{:<20} = {:>8.3f} M".format( + "Overall MACs (FVCore)**", flops_fvcore / 1e6 + ) + ) + print( + "\n** Theoretical and FVCore MACs may vary as theoretical MACs do not account " + "for certain operations which may or may not be accounted in FVCore" + ) + except Exception: + pass + + print("Note: Theoretical MACs depends on user-implementation. Be cautious") + logger.double_dash_line(dashes=65) + + return out_dict, overall_params, overall_macs diff --git a/Adaptive Frequency Filters/affnet/models/classification/config/__init__.py b/Adaptive Frequency Filters/affnet/models/classification/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/affnet/models/classification/config/affnet.py b/Adaptive Frequency Filters/affnet/models/classification/config/affnet.py new file mode 100644 index 0000000..0fbd296 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/classification/config/affnet.py @@ -0,0 +1,329 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from typing import Dict + +from utils import logger + + +def get_configuration(opts) -> Dict: + mode = getattr(opts, "model.classification.affnet.mode", "small") + if mode is None: + logger.error("Please specify mode") + + head_dim = getattr(opts, "model.classification.affnet.head_dim", None) + num_heads = getattr(opts, "model.classification.affnet.number_heads", 4) + if head_dim is not None: + if num_heads is not None: + logger.error( + "--model.classification.affnet.head-dim and --model.classification.affnet.number-heads " + "are mutually exclusive." + ) + elif num_heads is not None: + if head_dim is not None: + logger.error( + "--model.classification.affnet.head-dim and --model.classification.affnet.number-heads " + "are mutually exclusive." + ) + mode = mode.lower() + if mode == "xx_small": + mv2_exp_mult = 2 + config = { + "layer1": { + "out_channels": 32, + "expand_ratio": mv2_exp_mult, + "num_blocks": 1, + "stride": 1, + "block_type": "mv2", + }, + "layer2": { + "out_channels": 48, + "expand_ratio": mv2_exp_mult, + "num_blocks": 3, + "stride": 2, + "block_type": "mv2", + }, + "layer3": { # 28x28 + "out_channels": 64, + "transformer_channels": 64, + "ffn_dim": 128, + "transformer_blocks": 2, + "patch_h": 2, # 8, + "patch_w": 2, # 8, + "stride": 2, + "mv_expand_ratio": mv2_exp_mult, + "head_dim": head_dim, + "num_heads": num_heads, + "block_type": "mobilevit", + }, + "layer4": { # 14x14 + "out_channels": 104, + "transformer_channels": 104, + "ffn_dim": 208, + "transformer_blocks": 4, + "patch_h": 2, # 4, + "patch_w": 2, # 4, + "stride": 2, + "mv_expand_ratio": mv2_exp_mult, + "head_dim": head_dim, + "num_heads": num_heads, + "block_type": "mobilevit", + }, + "layer5": { # 7x7 + "out_channels": 144, + "transformer_channels": 144, + "ffn_dim": 288, + "transformer_blocks": 3, + "patch_h": 2, + "patch_w": 2, + "stride": 2, + "mv_expand_ratio": mv2_exp_mult, + "head_dim": head_dim, + "num_heads": num_heads, + "block_type": "mobilevit", + }, + "last_layer_exp_factor": 4, + } + elif mode == "x_small": + mv2_exp_mult = 4 + config = { + "layer1": { + "out_channels": 32, + "expand_ratio": mv2_exp_mult, + "num_blocks": 1, + "stride": 1, + "block_type": "mv2", + }, + "layer2": { + "out_channels": 48, + "expand_ratio": mv2_exp_mult, + "num_blocks": 3, + "stride": 2, + "block_type": "mv2", + }, + "layer3": { # 28x28 + "out_channels": 96, + "transformer_channels": 96, + "ffn_dim": 192, + "transformer_blocks": 2, + "patch_h": 2, + "patch_w": 2, + "stride": 2, + "mv_expand_ratio": mv2_exp_mult, + "head_dim": head_dim, + "num_heads": num_heads, + "block_type": "mobilevit", + }, + "layer4": { # 14x14 + "out_channels": 160, + "transformer_channels": 160, + "ffn_dim": 320, + "transformer_blocks": 4, + "patch_h": 2, + "patch_w": 2, + "stride": 2, + "mv_expand_ratio": mv2_exp_mult, + "head_dim": head_dim, + "num_heads": num_heads, + "block_type": "mobilevit", + }, + "layer5": { # 7x7 + "out_channels": 192, + "transformer_channels": 192, + "ffn_dim": 384, + "transformer_blocks": 3, + "patch_h": 2, + "patch_w": 2, + "stride": 2, + "mv_expand_ratio": mv2_exp_mult, + "head_dim": head_dim, + "num_heads": num_heads, + "block_type": "mobilevit", + }, + "last_layer_exp_factor": 4, + } + elif mode == "small": + mv2_exp_mult = 4 + config = { + "layer1": { + "out_channels": 32, + "expand_ratio": mv2_exp_mult, + "num_blocks": 1, + "stride": 1, + "block_type": "mv2", + }, + "layer2": { + "out_channels": 64, + "expand_ratio": mv2_exp_mult, + "num_blocks": 3, + "stride": 2, + "block_type": "mv2", + }, + "layer3": { # 28x28 + "out_channels": 128, + "transformer_channels": 128, + "ffn_dim": 256, + "transformer_blocks": 2, + "patch_h": 2, + "patch_w": 2, + "stride": 2, + "mv_expand_ratio": mv2_exp_mult, + "head_dim": head_dim, + "num_heads": num_heads, + "block_type": "mobilevit", + }, + "layer4": { # 14x14 + "out_channels": 256, + "transformer_channels": 256, + "ffn_dim": 512, + "transformer_blocks": 4, + "patch_h": 2, + "patch_w": 2, + "stride": 2, + "mv_expand_ratio": mv2_exp_mult, + "head_dim": head_dim, + "num_heads": num_heads, + "block_type": "mobilevit", + }, + "layer5": { # 7x7 + "out_channels": 320, + "transformer_channels": 320, + "ffn_dim": 640, + "transformer_blocks": 3, + "patch_h": 2, + "patch_w": 2, + "stride": 2, + "mv_expand_ratio": mv2_exp_mult, + "head_dim": head_dim, + "num_heads": num_heads, + "block_type": "mobilevit", + }, + "last_layer_exp_factor": 4, + } + elif mode == "base": + mv2_exp_mult = 4 + config = { + "layer1": { + "out_channels": 64, + "expand_ratio": mv2_exp_mult, + "num_blocks": 1, + "stride": 1, + "block_type": "mv2", + }, + "layer2": { + "out_channels": 128, + "expand_ratio": mv2_exp_mult, + "num_blocks": 3, + "stride": 2, + "block_type": "mv2", + }, + "layer3": { # 28x28 + "out_channels": 256, + "transformer_channels": 256, + "ffn_dim": 512, + "transformer_blocks": 2, + "patch_h": 2, + "patch_w": 2, + "stride": 2, + "mv_expand_ratio": mv2_exp_mult, + "head_dim": head_dim, + "num_heads": num_heads, + "block_type": "mobilevit", + }, + "layer4": { # 14x14 + "out_channels": 512, + "transformer_channels": 512, + "ffn_dim": 1024, + "transformer_blocks": 4, + "patch_h": 2, + "patch_w": 2, + "stride": 2, + "mv_expand_ratio": mv2_exp_mult, + "head_dim": head_dim, + "num_heads": num_heads, + "block_type": "mobilevit", + "no_fuse": True + }, + "layer5": { # 7x7 + "out_channels": 640, + "transformer_channels": 640, + "ffn_dim": 1280, + "transformer_blocks": 3, + "patch_h": 2, + "patch_w": 2, + "stride": 2, + "mv_expand_ratio": mv2_exp_mult, + "head_dim": head_dim, + "num_heads": num_heads, + "block_type": "mobilevit", + "no_fuse": True + }, + "last_layer_exp_factor": 4, + } + elif mode == "large": + mv2_exp_mult = 4 + config = { + "layer1": { + "out_channels": 64, + "expand_ratio": mv2_exp_mult, + "num_blocks": 2, + "stride": 1, + "block_type": "mv2", + }, + "layer2": { + "out_channels": 128, + "expand_ratio": mv2_exp_mult, + "num_blocks": 6, + "stride": 2, + "block_type": "mv2", + }, + "layer3": { # 28x28 + "out_channels": 256, + "transformer_channels": 256, + "ffn_dim": 512, + "transformer_blocks": 4, + "patch_h": 2, + "patch_w": 2, + "stride": 2, + "mv_expand_ratio": mv2_exp_mult, + "head_dim": head_dim, + "num_heads": num_heads, + "block_type": "mobilevit", + }, + "layer4": { # 14x14 + "out_channels": 512, + "transformer_channels": 512, + "ffn_dim": 1024, + "transformer_blocks": 18, + "patch_h": 2, + "patch_w": 2, + "stride": 2, + "mv_expand_ratio": mv2_exp_mult, + "head_dim": head_dim, + "num_heads": num_heads, + "block_type": "mobilevit", + "no_fuse": True + }, + "layer5": { # 7x7 + "out_channels": 768, + "transformer_channels": 768, + "ffn_dim": 1536, + "transformer_blocks": 6, + "patch_h": 2, + "patch_w": 2, + "stride": 2, + "mv_expand_ratio": mv2_exp_mult, + "head_dim": head_dim, + "num_heads": num_heads, + "block_type": "mobilevit", + "no_fuse": True + }, + "last_layer_exp_factor": 4, + } + else: + raise NotImplementedError + + return config diff --git a/Adaptive Frequency Filters/affnet/models/detection/__init__.py b/Adaptive Frequency Filters/affnet/models/detection/__init__.py new file mode 100644 index 0000000..d60f6b5 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/detection/__init__.py @@ -0,0 +1,139 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from .base_detection import BaseDetection +import os +import importlib +import argparse + +from utils.download_utils import get_local_path +from utils import logger +from utils.ddp_utils import is_master, is_start_rank_node +from utils.common_utils import check_frozen_norm_layer + +from .. import register_tasks, register_task_arguments +from ...misc.common import load_pretrained_model +from ...models.classification import build_classification_model + + +DETECT_MODEL_REGISTRY = {} + + +def register_detection_models(name): + def register_model_class(cls): + if name in DETECT_MODEL_REGISTRY: + raise ValueError("Cannot register duplicate model ({})".format(name)) + + if not issubclass(cls, BaseDetection): + raise ValueError( + "Model ({}: {}) must extend BaseDetection".format(name, cls.__name__) + ) + + DETECT_MODEL_REGISTRY[name] = cls + return cls + + return register_model_class + + +@register_tasks(name="detection") +def build_detection_model(opts): + det_model_name = getattr(opts, "model.detection.name", None) + model = None + is_master_node = is_master(opts) + if det_model_name in DETECT_MODEL_REGISTRY: + output_stride = getattr(opts, "model.detection.output_stride", None) + encoder = build_classification_model(opts=opts, output_stride=output_stride) + model = DETECT_MODEL_REGISTRY[det_model_name](opts, encoder) + else: + supported_models = list(DETECT_MODEL_REGISTRY.keys()) + supp_model_str = "Supported detection models are:" + for i, m_name in enumerate(supported_models): + supp_model_str += "\n\t {}: {}".format(i, logger.color_text(m_name)) + if is_master_node: + logger.error(supp_model_str) + + pretrained = getattr(opts, "model.detection.pretrained", None) + if pretrained is not None: + pretrained = get_local_path(opts, path=pretrained) + model = load_pretrained_model(model=model, wt_loc=pretrained, opts=opts) + + freeze_norm_layers = getattr(opts, "model.detection.freeze_batch_norm", False) + if freeze_norm_layers: + model.freeze_norm_layers() + frozen_state, count_norm = check_frozen_norm_layer(model) + if count_norm > 0 and frozen_state and is_master_node: + logger.error( + "Something is wrong while freezing normalization layers. Please check" + ) + + if is_master_node: + logger.log("Normalization layers are frozen") + + return model + + +def common_detection_args(parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Detection arguments", description="Detection arguments" + ) + + group.add_argument( + "--model.detection.name", type=str, default=None, help="Model name" + ) + group.add_argument( + "--model.detection.n-classes", + type=int, + default=80, + help="Number of classes in the dataset", + ) + group.add_argument( + "--model.detection.pretrained", + type=str, + default=None, + help="Path of the pretrained model", + ) + group.add_argument( + "--model.detection.output-stride", + type=int, + default=None, + help="Output stride in classification network", + ) + group.add_argument( + "--model.detection.replace-stride-with-dilation", + action="store_true", + help="Replace stride with dilation", + ) + group.add_argument( + "--model.detection.freeze-batch-norm", + action="store_true", + help="Freeze batch norm layers", + ) + + return parser + + +@register_task_arguments(name="detection") +def arguments_detection(parser: argparse.ArgumentParser): + parser = common_detection_args(parser) + + # add detection specific arguments + for k, v in DETECT_MODEL_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + + return parser + + +# automatically import the models +models_dir = os.path.dirname(__file__) +for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("affnet.models.detection." + model_name) diff --git a/Adaptive Frequency Filters/affnet/models/detection/base_detection.py b/Adaptive Frequency Filters/affnet/models/detection/base_detection.py new file mode 100644 index 0000000..dd49cc2 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/detection/base_detection.py @@ -0,0 +1,146 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Dict, Tuple +import argparse + +from utils import logger + +from ..classification import BaseEncoder +from ... import parameter_list +from ...misc.init_utils import initialize_weights + +from collections import namedtuple + + +DetectionPredTuple = namedtuple( + typename="DetectionPredTuple", + field_names=("labels", "scores", "boxes", "masks"), + defaults=(None, None, None, None), +) + + +class BaseDetection(nn.Module): + """ + Base class for the task of object detection + """ + + def __init__(self, opts, encoder: BaseEncoder) -> None: + super().__init__() + assert isinstance(encoder, BaseEncoder) + self.encoder: BaseEncoder = encoder + self.n_detection_classes = getattr(opts, "model.detection.n_classes", 80) + + enc_conf = self.encoder.model_conf_dict + + enc_ch_l5_out_proj = _check_out_channels(enc_conf, "exp_before_cls") + enc_ch_l5_out = _check_out_channels(enc_conf, "layer5") + enc_ch_l4_out = _check_out_channels(enc_conf, "layer4") + enc_ch_l3_out = _check_out_channels(enc_conf, "layer3") + enc_ch_l2_out = _check_out_channels(enc_conf, "layer2") + enc_ch_l1_out = _check_out_channels(enc_conf, "layer1") + + self.enc_l5_channels = enc_ch_l5_out + self.enc_l5_channels_exp = enc_ch_l5_out_proj + self.enc_l4_channels = enc_ch_l4_out + self.enc_l3_channels = enc_ch_l3_out + self.enc_l2_channels = enc_ch_l2_out + self.enc_l1_channels = enc_ch_l1_out + + self.opts = opts + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add model specific arguments""" + return parser + + @staticmethod + def reset_layer_parameters(layer, opts) -> None: + """Initialize weights of a given layer""" + initialize_weights(opts=opts, modules=layer.modules()) + + def get_trainable_parameters( + self, + weight_decay: float = 0.0, + no_decay_bn_filter_bias: bool = False, + *args, + **kwargs + ): + """Returns a list of trainable parameters""" + param_list = parameter_list( + named_parameters=self.named_parameters, + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + *args, + **kwargs + ) + return param_list, [1.0] * len(param_list) + + @staticmethod + def profile_layer(layer, input: Tensor) -> Tuple[Tensor, float, float]: + # profile a layer + block_params = block_macs = 0.0 + if isinstance(layer, nn.Sequential): + for layer_i in range(len(layer)): + input, layer_param, layer_macs = layer[layer_i].profile_module(input) + block_params += layer_param + block_macs += layer_macs + print( + "{:<15} \t {:<5}: {:>8.3f} M \t {:<5}: {:>8.3f} M".format( + layer[layer_i].__class__.__name__, + "Params", + round(layer_param / 1e6, 3), + "MACs", + round(layer_macs / 1e6, 3), + ) + ) + else: + input, layer_param, layer_macs = layer.profile_module(input) + block_params += layer_param + block_macs += layer_macs + print( + "{:<15} \t {:<5}: {:>8.3f} M \t {:<5}: {:>8.3f} M".format( + layer.__class__.__name__, + "Params", + round(layer_param / 1e6, 3), + "MACs", + round(layer_macs / 1e6, 3), + ) + ) + return input, block_params, block_macs + + def profile_model(self, input: Tensor): + """ + Child classes must implement this function to compute FLOPs and parameters + """ + raise NotImplementedError + + def dummy_input_and_label(self, batch_size: int) -> Dict: + """Create dummy input and labels for CI/CD purposes. Child classes must override it + if functionality is different. + """ + raise NotImplementedError + + +def _check_out_channels(config: Dict, layer_name: str) -> int: + enc_ch_l: Dict = config.get(layer_name, None) + if enc_ch_l is None or not enc_ch_l: + logger.error( + "Encoder does not define input-output mapping for {}: Got: {}".format( + layer_name, config + ) + ) + + enc_ch_l_out = enc_ch_l.get("out", None) + if enc_ch_l_out is None or not enc_ch_l_out: + logger.error( + "Output channels are not defined in {} of the encoder. Got: {}".format( + layer_name, enc_ch_l + ) + ) + + return enc_ch_l_out diff --git a/Adaptive Frequency Filters/affnet/models/detection/mask_rcnn.py b/Adaptive Frequency Filters/affnet/models/detection/mask_rcnn.py new file mode 100644 index 0000000..40ce821 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/detection/mask_rcnn.py @@ -0,0 +1,863 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse +from typing import Tuple, Dict, Union, Any, List + +import torch +from torch import nn, Tensor +from torch.nn import functional as F +from torchvision.models.detection.anchor_utils import AnchorGenerator + +# Faster and Mask-RCNN related imports +from torchvision.models.detection.mask_rcnn import MaskRCNN +from torchvision.ops import MultiScaleRoIAlign + +from utils import logger +from . import register_detection_models, BaseDetection +from .base_detection import DetectionPredTuple +from .utils.rcnn_utils import ( + FastRCNNConvFCHead, + RPNHead, + MaskRCNNHeads, + MaskRCNNPredictor, + FastRCNNPredictor, +) +from ... import parameter_list +from ...layers import ConvLayer, Identity +from ...models.classification import BaseEncoder + + +class MaskRCNNEncoder(nn.Module): + def __init__( + self, opts, encoder: BaseEncoder, output_strides: List, projection_channels: int + ) -> None: + use_fpn = not getattr(opts, "model.detection.mask_rcnn.disable_fpn", False) + super().__init__() + # set classifier and exp layers to Identity + encoder.conv_1x1_exp = Identity() + encoder.classifier = Identity() + + # add projection layers that projects encoder feature maps to `projection_channels` + backbone_proj_layers = nn.ModuleDict() + self.backbone_output_strides = sorted( + list({4, 8, 16, 32}.intersection(output_strides)) + ) + model_config = encoder.model_conf_dict + self.backbone_map = {} + fpn_proj_layers = nn.ModuleDict() if use_fpn else None + for os in self.backbone_output_strides: + if os == 4: + in_channels = model_config["layer2"]["out"] + backbone_os_str = "out_l2" + elif os == 8: + in_channels = model_config["layer3"]["out"] + backbone_os_str = "out_l3" + elif os == 16: + in_channels = model_config["layer4"]["out"] + backbone_os_str = "out_l4" + elif os == 32: + in_channels = model_config["layer5"]["out"] + backbone_os_str = "out_l5" + else: + raise NotImplementedError + + conv_layer = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=projection_channels, + kernel_size=1, + use_norm=True, + use_act=False, + ) + backbone_proj_layers.add_module(str(os), conv_layer) + self.backbone_map[os] = backbone_os_str + + if use_fpn: + fpn_layer = ConvLayer( + opts=opts, + in_channels=projection_channels, + out_channels=projection_channels, + kernel_size=3, + use_norm=True, + use_act=False, + ) + fpn_proj_layers.add_module(str(os), fpn_layer) + + # add extra layers if desired output stride is greater than 32. + extra_layers = nn.ModuleDict() + extra_layer_os = sorted( + list((set(self.backbone_output_strides) ^ set(output_strides))) + ) + for os in extra_layer_os: + conv_layer = ConvLayer( + opts=opts, + in_channels=projection_channels, + out_channels=projection_channels, + kernel_size=3, + stride=2, + use_norm=True, + use_act=False, + ) + extra_layers.add_module(str(os), conv_layer) + self.encoder = encoder + self.backbone_proj_layers = backbone_proj_layers + self.fpn_proj_layers = fpn_proj_layers + self.use_fpn = use_fpn + self.extra_layers = extra_layers + self.out_channels = projection_channels + self.augmented_tensor = None + + def get_augmented_tensor(self) -> Tensor: + return self.augmented_tensor + + def forward(self, x: Tensor) -> Dict[str, Tensor]: + # extract features from the backbone network + enc_end_points: Dict = self.encoder.extract_end_points_all(x) + + self.augmented_tensor = enc_end_points.pop("augmented_tensor", None) + + outputs_backbone: Dict = {} + # project backbone features + for (os, enc_key_name) in self.backbone_map.items(): + x_proj = self.backbone_proj_layers[str(os)]( + enc_end_points.pop(enc_key_name) + ) + outputs_backbone[f"{os}"] = x_proj + + if self.fpn_proj_layers: + # FPN + last_os = self.backbone_output_strides[-1] + prev_fm = outputs_backbone[f"{last_os}"] + prev_fm = self.fpn_proj_layers[f"{last_os}"](prev_fm) + for os in self.backbone_output_strides[:-1][::-1]: + curr_fm = outputs_backbone[f"{os}"] + feat_shape = curr_fm.shape[-2:] + inner_top_down = F.interpolate(prev_fm, size=feat_shape, mode="nearest") + prev_fm = self.fpn_proj_layers[f"{os}"](curr_fm + inner_top_down) + outputs_backbone[f"{os}"] = prev_fm + + if self.extra_layers: + prev_os = self.backbone_output_strides[-1] + for os, extra_layer in self.extra_layers.items(): + x_proj = extra_layer(outputs_backbone[f"{prev_os}"]) + outputs_backbone[f"{os}"] = x_proj + prev_os = os + return outputs_backbone + + +@register_detection_models("mask_rcnn") +class MaskRCNNDetector(BaseDetection): + """ + This class implements a `Mask RCNN style object detector ` + + Args: + opts: command-line arguments + encoder (BaseEncoder): Encoder network (e.g., ResNet or MobileViT) + """ + + def __init__(self, opts, encoder: BaseEncoder): + super().__init__(opts, encoder) + default_norm = self.set_norm_layer_opts() + + output_strides = getattr( + opts, "model.detection.mask_rcnn.output_strides", [4, 8, 16, 32, 64] + ) + if len(output_strides) == 0: + logger.error( + "Please specify output strides for extracting backbone feature maps " + "using --model.detection.mask-rcnn.output-strides" + ) + output_strides = sorted(output_strides) + projection_channels = getattr( + opts, "model.detection.mask_rcnn.backbone_projection_channels", 256 + ) + + # anchor sizes and aspect ratios + anchor_sizes = getattr( + opts, "model.detection.mask_rcnn.anchor_sizes", [32, 64, 128, 256, 512] + ) + # convert to a tuples + if anchor_sizes is None: + logger.error("Anchor sizes can't be None") + elif len(anchor_sizes) != len(output_strides): + logger.error( + "Number of anchor sizes should be the same as the output stride. Got: {} and {}".format( + anchor_sizes, output_strides + ) + ) + elif isinstance(anchor_sizes, List) and isinstance(anchor_sizes[0], List): + # anchor sizes is a list of list. Convert to tuple + anchor_sizes = tuple([tuple(a_size) for a_size in anchor_sizes]) + elif isinstance(anchor_sizes, List) and isinstance(anchor_sizes[0], int): + # anchor sizes is a list of integers. Convert to tuple + anchor_sizes = tuple([(a_size,) for a_size in anchor_sizes]) + else: + raise NotImplementedError + + aspect_ratios = getattr( + opts, "model.detection.mask_rcnn.aspect_ratio", [0.5, 1.0, 2.0] + ) # ((0.5, 1.0, 2.0),) * len(anchor_sizes) + if aspect_ratios is None: + logger.error("Aspect ratios can't be None") + elif isinstance(aspect_ratios, (int, float)): + aspect_ratios = ((aspect_ratios,),) * len(anchor_sizes) + elif isinstance(aspect_ratios, List): + aspect_ratios = (tuple(aspect_ratios),) * len(anchor_sizes) + else: + raise NotImplementedError + + # feature map size for the bbox head + box_fm_size = getattr(opts, "model.detection.mask_rcnn.bbox_head_fm_size", 7) + mask_fm_size = getattr(opts, "model.detection.mask_rcnn.mask_head_fm_size", 14) + + # set-up the backbone + backbone = MaskRCNNEncoder( + opts, + encoder=encoder, + output_strides=output_strides, + projection_channels=projection_channels, + ) + + # create RPN anchor generator + rpn_anchor_generator = AnchorGenerator( + sizes=anchor_sizes, aspect_ratios=aspect_ratios + ) + + # create RPN Head + rpn_head = RPNHead( + opts=opts, + in_channels=projection_channels, + num_anchors=rpn_anchor_generator.num_anchors_per_location()[0], + conv_depth=2, + ) + + # box related parameters + representation_size = getattr( + opts, "model.detection.mask_rcnn.representation_size", 1024 + ) + output_strides_str = [str(os) for os in output_strides] + box_roi_pool = MultiScaleRoIAlign( + featmap_names=output_strides_str, output_size=box_fm_size, sampling_ratio=2 + ) + + box_fm_size_conv_layer = getattr( + opts, "model.detection.mask_rcnn.box_fm_size_conv_layer", [256] * 4 + ) + box_head = FastRCNNConvFCHead( + opts=opts, + input_size=(projection_channels, box_fm_size, box_fm_size), + conv_layers=box_fm_size_conv_layer, + fc_layers=[representation_size], + ) + + box_predictor = FastRCNNPredictor( + in_channels=representation_size, num_classes=self.n_detection_classes + ) + + # mask related parameters + mask_fm_size_conv_layer = getattr( + opts, "model.detection.mask_rcnn.mask_fm_size_conv_layer", [256] * 4 + ) + mask_dilation = getattr(opts, "model.detection.mask_rcnn.mask_dilation", 1) + mask_roi_pool = MultiScaleRoIAlign( + featmap_names=output_strides_str, output_size=mask_fm_size, sampling_ratio=2 + ) + + mask_dilation = mask_dilation + mask_head = MaskRCNNHeads( + opts=opts, + in_channels=projection_channels, + layers=mask_fm_size_conv_layer, + dilation=mask_dilation, + ) + + mask_predictor = MaskRCNNPredictor( + opts=opts, + in_channels=mask_fm_size_conv_layer[-1], + dim_reduced=256, + num_classes=self.n_detection_classes, + ) + + # RPN and box detection related hyper-parameters + rpn_pre_nms_top_n_train = getattr( + opts, "model.detection.mask_rcnn.rpn_pre_nms_top_n_train", 2000 + ) + rpn_pre_nms_top_n_test = getattr( + opts, "model.detection.mask_rcnn.rpn_pre_nms_top_n_test", 1000 + ) + rpn_post_nms_top_n_train = getattr( + opts, "model.detection.mask_rcnn.rpn_post_nms_top_n_train", 2000 + ) + rpn_post_nms_top_n_test = getattr( + opts, "model.detection.mask_rcnn.rpn_post_nms_top_n_test", 1000 + ) + rpn_nms_thresh = getattr(opts, "model.detection.mask_rcnn.rpn_nms_thresh", 0.7) + rpn_fg_iou_thresh = getattr( + opts, "model.detection.mask_rcnn.rpn_fg_iou_thresh", 0.7 + ) + rpn_bg_iou_thresh = getattr( + opts, "model.detection.mask_rcnn.rpn_bg_iou_thresh", 0.3 + ) + rpn_batch_size_per_image = getattr( + opts, "model.detection.mask_rcnn.rpn_batch_size_per_image", 256 + ) + rpn_positive_fraction = getattr( + opts, "model.detection.mask_rcnn.rpn_positive_fraction", 0.5 + ) + rpn_score_thresh = getattr( + opts, "model.detection.mask_rcnn.rpn_score_thresh", 0.0 + ) + + box_score_thresh = getattr( + opts, "model.detection.mask_rcnn.box_score_thresh", 0.05 + ) + box_nms_thresh = getattr(opts, "model.detection.mask_rcnn.box_nms_thresh", 0.5) + box_detections_per_img = getattr( + opts, "model.detection.mask_rcnn.box_detections_per_img", 100 + ) + box_fg_iou_thresh = getattr( + opts, "model.detection.mask_rcnn.box_fg_iou_thresh", 0.5 + ) + box_bg_iou_thresh = getattr( + opts, "model.detection.mask_rcnn.box_bg_iou_thresh", 0.5 + ) + box_batch_size_per_image = getattr( + opts, "model.detection.mask_rcnn.box_batch_size_per_image", 512 + ) + box_positive_fraction = getattr( + opts, "model.detection.mask_rcnn.box_positive_fraction", 0.25 + ) + + # kwargs = {"_skip_resize": True} + self.model = MaskRCNN( + backbone=backbone, + # num_classes=None, #self.n_detection_classes, + # In affnet, we don't use mean-std normalization + image_mean=[0.0] * 3, + image_std=[1.0] * 3, + # RPN parameters + rpn_anchor_generator=rpn_anchor_generator, + rpn_head=rpn_head, + rpn_pre_nms_top_n_train=rpn_pre_nms_top_n_train, + rpn_pre_nms_top_n_test=rpn_pre_nms_top_n_test, + rpn_post_nms_top_n_train=rpn_post_nms_top_n_train, + rpn_post_nms_top_n_test=rpn_post_nms_top_n_test, + rpn_nms_thresh=rpn_nms_thresh, + rpn_fg_iou_thresh=rpn_fg_iou_thresh, + rpn_bg_iou_thresh=rpn_bg_iou_thresh, + rpn_batch_size_per_image=rpn_batch_size_per_image, + rpn_positive_fraction=rpn_positive_fraction, + rpn_score_thresh=rpn_score_thresh, + # Box parameters + box_roi_pool=box_roi_pool, + box_head=box_head, + box_score_thresh=box_score_thresh, + box_nms_thresh=box_nms_thresh, + box_detections_per_img=box_detections_per_img, + box_fg_iou_thresh=box_fg_iou_thresh, + box_bg_iou_thresh=box_bg_iou_thresh, + box_batch_size_per_image=box_batch_size_per_image, + box_positive_fraction=box_positive_fraction, + bbox_reg_weights=None, + box_predictor=box_predictor, + # Mask parameters + mask_roi_pool=mask_roi_pool, + mask_head=mask_head, + mask_predictor=mask_predictor + # **kwargs + ) + + self.backbone_lr_multiplier = getattr( + opts, "model.detection.mask_rcnn.backbone_lr_multiplier", 1.0 + ) + del self.encoder + + self.reset_norm_layer_opts(default_norm=default_norm) + + def set_norm_layer_opts(self): + mask_rcnn_norm_layer = getattr( + self.opts, "model.detection.mask_rcnn.norm_layer", None + ) + if mask_rcnn_norm_layer is None: + logger.error("Please specify norm layer") + + default_norm = getattr(self.opts, "model.normalization.name", None) + setattr(self.opts, "model.normalization.name", mask_rcnn_norm_layer) + return default_norm + + def reset_norm_layer_opts(self, default_norm): + setattr(self.opts, "model.normalization.name", default_norm) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add model specific arguments""" + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--model.detection.mask-rcnn.backbone-projection-channels", + type=int, + default=256, + help="Projection channels for the encoder in Mask-RCNN", + ) + + group.add_argument( + "--model.detection.mask-rcnn.backbone-lr-multiplier", + type=float, + default=1.0, + help="LR multiplier for MASK RCNN head", + ) + + group.add_argument( + "--model.detection.mask-rcnn.output-strides", + type=int, + nargs="+", + default=[4, 8, 16, 32, 64], + help="Extract backbone feature maps from these output strides. " + "If output stride is greater than 32, extra layers are added.", + ) + group.add_argument( + "--model.detection.mask-rcnn.anchor-sizes", + type=int, + nargs="+", + action="append", + default=[32, 64, 128, 256, 512], + help="Anchor sizes at each output stride", + ) + group.add_argument( + "--model.detection.mask-rcnn.aspect-ratio", + type=float, + nargs="+", + default=[0.5, 1.0, 2.0], + help="Aspect ratios. These are the same for all feature maps", + ) + + group.add_argument( + "--model.detection.mask-rcnn.bbox-head-fm-size", + type=int, + default=7, + help="Feature map size for the box head", + ) + group.add_argument( + "--model.detection.mask-rcnn.mask-head-fm-size", + type=int, + default=14, + help="Feature map size for the max head", + ) + group.add_argument( + "--model.detection.mask-rcnn.representation-size", + type=int, + default=1024, + help="Size of the intermediate representation in Mask RCNN", + ) + # box_fm_size_conv_layer = getattr(opts, "", [256] * 4) + group.add_argument( + "--model.detection.mask-rcnn.box-fm-size-conv-layer", + type=int, + nargs="+", + default=[256] * 4, + help="Feature dim of each Convolution layer in the Faster RCNN head. Defaults to [256, 256, 256, 256]", + ) + group.add_argument( + "--model.detection.mask-rcnn.mask-fm-size-conv-layer", + type=int, + nargs="+", + default=[256] * 4, + help="Feature dim of each Convolution layer in the Mask RCNN head. Defaults to [256, 256, 256, 256]", + ) + group.add_argument( + "--model.detection.mask-rcnn.mask-dilation", + type=int, + default=1, + help="Dilation rate in Mask RCNN head. Defaults to 1", + ) + + group.add_argument( + "--model.detection.mask-rcnn.rpn-pre-nms-top-n-train", + type=int, + default=2000, + help="Number of proposals to keep before applying NMS during training", + ) + + group.add_argument( + "--model.detection.mask-rcnn.rpn-pre-nms-top-n-test", + type=int, + default=1000, + help="Number of proposals to keep before applying NMS during test", + ) + + group.add_argument( + "--model.detection.mask-rcnn.rpn-post-nms-top-n-train", + type=int, + default=2000, + help="Number of proposals to keep after applying NMS during training", + ) + + group.add_argument( + "--model.detection.mask-rcnn.rpn-post-nms-top-n-test", + type=int, + default=1000, + help="Number of proposals to keep after applying NMS during test", + ) + + group.add_argument( + "--model.detection.mask-rcnn.rpn-nms-thresh", + type=float, + default=0.7, + help="NMS threshold used for postprocessing the RPN proposals", + ) + + group.add_argument( + "--model.detection.mask-rcnn.rpn-fg-iou-thresh", + type=float, + default=0.7, + help="minimum IoU between the anchor and the GT box so that they can be " + "considered as positive during training of the RPN.", + ) + + group.add_argument( + "--model.detection.mask-rcnn.rpn-bg-iou-thresh", + type=float, + default=0.7, + help="minimum IoU between the anchor and the GT box so that they can be " + "considered as negative during training of the RPN.", + ) + + group.add_argument( + "--model.detection.mask-rcnn.rpn-batch-size-per-image", + type=int, + default=256, + help="Number of anchors that are sampled during training of the RPN for computing the loss", + ) + + group.add_argument( + "--model.detection.mask-rcnn.rpn-positive-fraction", + type=float, + default=0.5, + help="Proportion of positive anchors in a mini-batch during training of the RPN", + ) + + group.add_argument( + "--model.detection.mask-rcnn.rpn-score-thresh", + type=float, + default=0.0, + help="During inference, only return proposals with a classification score greater than rpn_score_thresh", + ) + + # + group.add_argument( + "--model.detection.mask-rcnn.box-score-thresh", + type=float, + default=0.05, + help="During inference, only return proposals with a classification score greater than box_score_thresh", + ) + + group.add_argument( + "--model.detection.mask-rcnn.box-nms-thresh", + type=float, + default=0.5, + help="During inference, NMS threshold for the prediction head.", + ) + + group.add_argument( + "--model.detection.mask-rcnn.box-detections-per-img", + type=int, + default=100, + help="Maximum number of detections per image, for all classes", + ) + + group.add_argument( + "--model.detection.mask-rcnn.box-fg-iou-thresh", + type=float, + default=0.5, + help="Minimum IoU between the proposals and the GT box so that they can be considered as " + "positive during training of the classification head", + ) + + group.add_argument( + "--model.detection.mask-rcnn.box-bg-iou-thresh", + type=float, + default=0.5, + help="Minimum IoU between the proposals and the GT box so that they can be considered as " + "negative during training of the classification head", + ) + + group.add_argument( + "--model.detection.mask-rcnn.box-batch-size-per-image", + type=int, + default=512, + help="Number of proposals that are sampled during training of the classification head", + ) + + group.add_argument( + "--model.detection.mask-rcnn.box-positive-fraction", + type=float, + default=0.25, + help="Proportion of positive proposals in a mini-batch during training of the classification head", + ) + + group.add_argument( + "--model.detection.mask-rcnn.norm-layer", + type=str, + default=None, + help="Mask RCNN Norm layer", + ) + + group.add_argument( + "--model.detection.mask-rcnn.disable-fpn", + action="store_true", + help="Do not use FPN", + ) + return parser + + def reset_generalized_rcnn_transform(self, height, width): + self.model.transform.fixed_size = (width, height) + + def get_trainable_parameters( + self, + weight_decay: float = 0.0, + no_decay_bn_filter_bias: bool = False, + *args, + **kwargs, + ): + """Returns a list of trainable parameters""" + if self.backbone_lr_multiplier == 1.0: + return super(MaskRCNNDetector, self).get_trainable_parameters( + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + *args, + **kwargs, + ) + else: + + all_params = [] + all_params_lr = [] + + # pre-trained encoder parameters + backbone_param_list = parameter_list( + named_parameters=self.model.backbone.encoder.named_parameters, + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + module_name="model.backbone.encoder.", + *args, + **kwargs, + ) + + all_params.extend(backbone_param_list) + + all_params_lr.extend( + [self.backbone_lr_multiplier] * len(backbone_param_list) + ) + + if self.model.backbone.backbone_proj_layers: + # projection layer parameters + projection_param_list = parameter_list( + named_parameters=self.model.backbone.backbone_proj_layers.named_parameters, + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + module_name="model.backbone.backbone_proj_layers.", + *args, + **kwargs, + ) + + all_params.extend(projection_param_list) + + all_params_lr.extend([1.0] * len(projection_param_list)) + + if self.model.backbone.fpn_proj_layers: + # projection layer parameters + fpn_projection_param_list = parameter_list( + named_parameters=self.model.backbone.fpn_proj_layers.named_parameters, + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + module_name="model.backbone.fpn_proj_layers.", + *args, + **kwargs, + ) + + all_params.extend(fpn_projection_param_list) + + all_params_lr.extend([1.0] * len(fpn_projection_param_list)) + + if self.model.backbone.extra_layers: + # extra layer parameters + extra_layer_param_list = parameter_list( + named_parameters=self.model.backbone.extra_layers.named_parameters, + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + module_name="model.backbone.extra_layers.", + *args, + **kwargs, + ) + + all_params.extend(extra_layer_param_list) + + all_params_lr.extend([1.0] * len(extra_layer_param_list)) + + # rpn parameters + rpn_param_list = parameter_list( + named_parameters=self.model.rpn.named_parameters, + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + module_name="model.rpn.", + *args, + **kwargs, + ) + + all_params.extend(rpn_param_list) + all_params_lr.extend([1.0] * len(rpn_param_list)) + + # ROI head params + roi_param_list = parameter_list( + named_parameters=self.model.roi_heads.named_parameters, + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + module_name="model.roi_heads.", + *args, + **kwargs, + ) + + all_params.extend(roi_param_list) + all_params_lr.extend([1.0] * len(roi_param_list)) + + return all_params, all_params_lr + + def forward( + self, x: Dict, *args, **kwargs + ) -> Union[Tuple[Tensor, ...], Tuple[Any, ...], Dict]: + + if isinstance(x, Dict): + input_tensor = x["image"] + input_labels = x["label"] + else: + raise NotImplementedError( + "Input to MaskRCNN should be a Dict of List of Tensors" + ) + + assert isinstance(input_tensor, List) + assert isinstance(input_labels, List) + + in_channels, in_height, in_width = input_tensor[0].shape + + self.reset_generalized_rcnn_transform(height=in_height, width=in_width) + + # The mask rcnn model expects labels, since it computes the loss. + outputs = self.model(input_tensor, targets=input_labels) + + if not self.training: + detections = [] + for i, elem in enumerate(outputs): + # We must normalize by image size, since this is what the downstream + # evaluator expects. + elem["boxes"][:, 0::2] /= input_tensor[i].shape[2] + elem["boxes"][:, 1::2] /= input_tensor[i].shape[1] + + # predicted masks are in [N, 1, H, W] format + # for evaluation, we need them in [N, H, W] format + masks = elem["masks"] + # [N, 1, H, W] --> [N, H, W] + masks = masks.squeeze(1) + + elem_detections = DetectionPredTuple( + labels=elem["labels"], + scores=elem["scores"], + boxes=elem["boxes"], + masks=masks, + ) + detections.append(elem_detections) + return {"detections": detections} + + if hasattr(self.model.backbone, "get_augmented_tensor"): + outputs["augmented_tensor"] = self.model.backbone.get_augmented_tensor() + + return outputs + + @torch.no_grad() + def predict(self, x: Tensor, *args, **kwargs) -> DetectionPredTuple: + """Predict the bounding boxes given an image tensor""" + assert isinstance(x, Tensor) and x.ndim == 4, "Expected 4D tensor as an input" + + bsz, channels, in_height, in_width = x.shape + if bsz != 1: + logger.error( + "Prediction is supported with a batch size of 1 in {}".format( + self.__class__.__name__ + ) + ) + + self.reset_generalized_rcnn_transform(height=in_height, width=in_width) + + outputs = self.model(x) + + if isinstance(outputs, List) and len(outputs) == 1: + outputs = outputs[0] + + if isinstance(outputs, Dict) and {"boxes", "labels", "scores"}.issubset( + outputs.keys() + ): + # resize the boxes + outputs["boxes"][:, 0::2] /= in_width + outputs["boxes"][:, 1::2] /= in_height + + # predicted masks are in [N, 1, H, W] format + # for evaluation, we need them in [N, H, W] format + masks = outputs["masks"] + # [N, 1, H, W] --> [N, H, W] + masks = masks.squeeze(1) + + detections = DetectionPredTuple( + labels=outputs["labels"], + scores=outputs["scores"], + boxes=outputs["boxes"], + masks=masks, + ) + return detections + else: + logger.error( + "Output should be a dict with boxes, scores, and labels as keys. Got: {}".format( + type(outputs) + ) + ) + + def dummy_input_and_label(self, batch_size: int) -> Dict: + """Create dummy input and labels for CI/CD purposes.""" + img_channels = 3 + height = 320 + width = 320 + n_classes = 80 + + # GT boxes have the same shape as anchors. So, we use anchors as GT boxes + n_boxes = 1 + + gt_boxes = torch.tensor([2, 20, 3, 40]).reshape(-1, 4).float() + gt_box_labels = torch.randint( + low=0, + high=n_classes, + size=(n_boxes,), + dtype=torch.long, + ) + + img_tensor = torch.randn(img_channels, height, width, dtype=torch.float) + labels = { + "box_labels": gt_box_labels, + "box_coordinates": gt_boxes, + } + + return { + "samples": { + "image": [img_tensor] * batch_size, + "label": [ + { + "labels": gt_box_labels, + "boxes": gt_boxes, + "masks": torch.zeros(1, height, width, dtype=torch.long), + } + ] + * batch_size, + }, + "targets": labels, + } diff --git a/Adaptive Frequency Filters/affnet/models/detection/ssd.py b/Adaptive Frequency Filters/affnet/models/detection/ssd.py new file mode 100644 index 0000000..15e6657 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/detection/ssd.py @@ -0,0 +1,686 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import copy + +import torch +from torch import nn, Tensor +from utils import logger +import argparse +from typing import Optional, Tuple, Dict, Union, Any, List +from torchvision.ops import batched_nms +from torch.nn import functional as F +import math + +from affnet.anchor_generator import build_anchor_generator +from affnet.matcher_det import build_matcher +from utils.common_utils import is_coreml_conversion + +from . import register_detection_models +from ... import parameter_list + +from .base_detection import BaseDetection, DetectionPredTuple +from ...layers import ConvLayer, SeparableConv, AdaptiveAvgPool2d +from ...modules import SSDHead, SSDInstanceHead +from ...models.classification import BaseEncoder +from ...misc.init_utils import initialize_conv_layer +from ...misc.profiler import module_profile + + +@register_detection_models("ssd") +class SingleShotMaskDetector(BaseDetection): + """ + This class implements a `Single Shot Object Detector `_ + + Args: + opts: command-line arguments + encoder (BaseEncoder): Encoder network (e.g., ResNet or MobileViT) + """ + + coordinates = 4 # 4 coordinates (x1, y1, x2, y2) or (x, y, w, h) + + def __init__(self, opts, encoder: BaseEncoder) -> None: + + anchor_gen_name = getattr(opts, "anchor_generator.name", None) + if anchor_gen_name is None or anchor_gen_name != "ssd": + logger.error("For SSD, we need --anchor-generator.name to be ssd") + anchor_box_generator = build_anchor_generator(opts=opts) + + output_strides_aspect_ratio = anchor_box_generator.output_strides_aspect_ratio + output_strides = list(output_strides_aspect_ratio.keys()) + anchors_aspect_ratio = list(output_strides_aspect_ratio.values()) + + n_os = len(output_strides) + + if getattr(opts, "matcher.name") != "ssd": + logger.error("For SSD, we need --matcher.name as ssd") + + super().__init__(opts=opts, encoder=encoder) + + # delete layers that are not required in detection network + self.encoder.classifier = None + self.encoder.conv_1x1_exp = None + + proj_channels = getattr( + opts, "model.detection.ssd.proj_channels", [512, 256, 256, 128, 128, 64] + ) + + proj_channels = proj_channels + [128] * (n_os - len(proj_channels)) + + if n_os != len(anchors_aspect_ratio) != len(proj_channels): + logger.error( + "SSD model requires anchors to be defined for feature maps from each output stride. Also" + "len(anchors_aspect_ratio) == len(output_strides) == len(proj_channels). " + "Got len(output_strides)={}, len(anchors_aspect_ratio)={}, len(proj_channels)={}." + " Please specify correct arguments using following arguments: " + "\n--model.detection.ssd.anchors-aspect-ratio " + "\n--model.detection.ssd.output-strides" + "\n--model.detection.ssd.proj-channels".format( + n_os, len(anchors_aspect_ratio), len(proj_channels) + ) + ) + extra_layers = {} + enc_channels_list = [] + in_channels = self.enc_l5_channels + + extra_proj_list = [256] * (len(output_strides) - len(proj_channels)) + proj_channels = proj_channels + extra_proj_list + for idx, os in enumerate(output_strides): + out_channels = proj_channels[idx] + if os == 8: + enc_channels_list.append(self.enc_l3_channels) + elif os == 16: + enc_channels_list.append(self.enc_l4_channels) + elif os == 32: + enc_channels_list.append(self.enc_l5_channels) + elif os > 32 and os != -1: + extra_layers["os_{}".format(os)] = SeparableConv( + opts=opts, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + use_act=True, + use_norm=True, + stride=2, + ) + enc_channels_list.append(out_channels) + in_channels = out_channels + elif os == -1: + extra_layers["os_{}".format(os)] = nn.Sequential( + AdaptiveAvgPool2d(output_size=1), + ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + use_act=True, + use_norm=False, + ), + ) + enc_channels_list.append(out_channels) + in_channels = out_channels + else: + raise NotImplementedError + self.extra_layers = None if not extra_layers else nn.ModuleDict(extra_layers) + if self.extra_layers is not None: + self.reset_layers(module=self.extra_layers) + + self.fpn = None + if getattr(opts, "model.detection.ssd.use_fpn", False): + from ...modules import FeaturePyramidNetwork + + fpn_channels = getattr(opts, "model.detection.ssd.fpn_out_channels", 256) + self.fpn = FeaturePyramidNetwork( + opts=opts, + in_channels=enc_channels_list, + output_strides=output_strides, + out_channels=fpn_channels, + ) + # update the enc_channels_list + enc_channels_list = [fpn_channels] * len(output_strides) + # for FPN, we do not need to do projections + proj_channels = enc_channels_list + + # Anchor box related parameters + self.conf_threshold = getattr(opts, "model.detection.ssd.conf_threshold", 0.01) + self.nms_threshold = getattr(opts, "model.detection.ssd.nms_iou_threshold", 0.5) + self.top_k = getattr(opts, "model.detection.ssd.top_k", 400) + self.objects_per_image = getattr( + opts, "model.detection.ssd.objects_per_image", 200 + ) + + self.anchor_box_generator = anchor_box_generator + + anchors_aspect_ratio = self.anchor_box_generator.num_anchors_per_os() + + # Create SSD detection and classification heads + anchor_steps = self.anchor_box_generator.step + + self.ssd_heads = nn.ModuleList() + + for os, in_dim, proj_dim, n_anchors, step in zip( + output_strides, + enc_channels_list, + proj_channels, + anchors_aspect_ratio, + anchor_steps, + ): + self.ssd_heads += [ + SSDHead( + opts=opts, + in_channels=in_dim, + n_classes=self.n_detection_classes, + n_coordinates=self.coordinates, + n_anchors=n_anchors, + proj_channels=proj_dim, + kernel_size=3 if os != -1 else 1, + stride=step, + ) + ] + + self.anchors_aspect_ratio = anchors_aspect_ratio + self.output_strides = output_strides + + self.match_prior = build_matcher(opts=opts) + self.step = self.anchor_box_generator.step + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--model.detection.ssd.anchors-aspect-ratio", + type=int, + nargs="+", + action="append", + default=[[2, 3]] * 4, + help="Anchors aspect ratio in each feature map obtained at different output strides.", + ) + group.add_argument( + "--model.detection.ssd.output-strides", + type=int, + nargs="+", + default=[16, 32, 64, 128], + help="Extract feature maps from these output strides.", + ) + group.add_argument( + "--model.detection.ssd.proj-channels", + type=int, + nargs="+", + default=[512] * 4, + help="Projection channels for feature map obtained at each output stride", + ) + + # depreciated + group.add_argument( + "--model.detection.ssd.min-box-size", + type=float, + default=None, + help="Min. box size. Value between 0 and 1. Good default value is 0.1", + ) + group.add_argument( + "--model.detection.ssd.max-box-size", + type=float, + default=None, + help="Max. box size. Value between 0 and 1. Good default value is 1.05", + ) + + # Depreciated + group.add_argument( + "--model.detection.ssd.center-variance", + type=float, + default=None, + help="Center variance.", + ) + group.add_argument( + "--model.detection.ssd.size-variance", + type=float, + default=None, + help="Size variance.", + ) + group.add_argument( + "--model.detection.ssd.iou-threshold", + type=float, + default=None, + help="IOU Threshold.", + ) + + # inference related arguments + group.add_argument( + "--model.detection.ssd.conf-threshold", + type=float, + default=0.01, + help="Confidence threshold. For evaluation on COCO, set to 0.01, so that we can compute mAP", + ) + group.add_argument( + "--model.detection.ssd.top-k", + type=int, + default=400, + help="Keep only top-k objects before NMS", + ) + group.add_argument( + "--model.detection.ssd.objects-per-image", + type=int, + default=200, + help="Keep only these many objects after NMS", + ) + group.add_argument( + "--model.detection.ssd.nms-iou-threshold", + type=float, + default=0.5, + help="NMS IoU threshold ", + ) + + # FPN + group.add_argument( + "--model.detection.ssd.fpn-out-channels", + type=int, + default=256, + help="Number of output channels in FPN", + ) + group.add_argument( + "--model.detection.ssd.use-fpn", + action="store_true", + help="Use SSD with FPN", + ) + + return parser + + @staticmethod + def reset_layers(module) -> None: + for layer in module.modules(): + if isinstance(layer, nn.Conv2d): + initialize_conv_layer(module=layer, init_method="xavier_uniform") + + @staticmethod + def process_anchors_ar(anchor_ar: List) -> List: + assert isinstance(anchor_ar, list) + new_ar = [] + for ar in anchor_ar: + if ar in new_ar: + continue + new_ar.append(ar) + return new_ar + + def get_backbone_features(self, x: Tensor) -> Dict[str, Tensor]: + # extract features from the backbone network + enc_end_points: Dict = self.encoder.extract_end_points_all(x) + + end_points: Dict = dict() + for idx, os in enumerate(self.output_strides): + if os == 8: + end_points["os_{}".format(os)] = enc_end_points.pop("out_l3") + elif os == 16: + end_points["os_{}".format(os)] = enc_end_points.pop("out_l4") + elif os == 32: + end_points["os_{}".format(os)] = enc_end_points.pop("out_l5") + else: + x = end_points["os_{}".format(self.output_strides[idx - 1])] + end_points["os_{}".format(os)] = self.extra_layers["os_{}".format(os)]( + x + ) + + if self.fpn is not None: + # apply Feature Pyramid Network + end_points = self.fpn(end_points) + + return end_points + + def ssd_forward( + self, + end_points: Dict[str, Tensor], + device: Optional[torch.device] = torch.device("cpu"), + *args, + **kwargs + ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, ...]]: + + locations = [] + confidences = [] + anchors = [] + + for os, ssd_head in zip(self.output_strides, self.ssd_heads): + x = end_points["os_{}".format(os)] + fm_h, fm_w = x.shape[2:] + loc, pred = ssd_head(x) + + locations.append(loc) + confidences.append(pred) + + anchors_fm_ctr = self.anchor_box_generator( + fm_height=fm_h, fm_width=fm_w, fm_output_stride=os, device=device + ) + anchors.append(anchors_fm_ctr) + + locations = torch.cat(locations, dim=1) + confidences = torch.cat(confidences, dim=1) + + anchors = torch.cat(anchors, dim=0) + anchors = anchors.unsqueeze(dim=0) + + return confidences, locations, anchors + + def forward( + self, x: Union[Tensor, Dict] + ) -> Union[Tuple[Tensor, ...], Tuple[Any, ...], Dict]: + if isinstance(x, Dict): + input_tensor = x["image"] + elif isinstance(x, Tensor): + input_tensor = x + else: + raise NotImplementedError( + "Input to SSD should be either a Tensor or a Dict of Tensors" + ) + + device = input_tensor.device + backbone_end_points: Dict = self.get_backbone_features(input_tensor) + + if not is_coreml_conversion(self.opts): + confidences, locations, anchors = self.ssd_forward( + end_points=backbone_end_points, device=device + ) + + output_dict = {"scores": confidences, "boxes": locations} + + if not self.training: + # compute the detection results during evaluation + scores = nn.Softmax(dim=-1)(confidences) + boxes = self.match_prior.convert_to_boxes( + pred_locations=locations, anchors=anchors + ) + + detections = self.postprocess_detections(boxes=boxes, scores=scores) + output_dict["detections"] = detections + + return output_dict + else: + return self.ssd_forward(end_points=backbone_end_points, is_prediction=False) + + @torch.no_grad() + def predict(self, x: Tensor, *args, **kwargs) -> DetectionPredTuple: + """Predict the bounding boxes given an image tensor""" + bsz, channels, width, height = x.shape + if bsz != 1: + logger.error( + "Prediction is supported with a batch size of 1 in {}".format( + self.__class__.__name__ + ) + ) + + device = x.device + enc_end_points: Dict = self.get_backbone_features(x) + confidences, locations, anchors = self.ssd_forward( + end_points=enc_end_points, device=device + ) + + scores = nn.Softmax(dim=-1)(confidences) + + boxes = self.match_prior.convert_to_boxes( + pred_locations=locations, anchors=anchors + ) + detections = self.postprocess_detections(boxes=boxes, scores=scores)[0] + return detections + + @torch.no_grad() + def postprocess_detections( + self, boxes: Tensor, scores: Tensor + ) -> List[DetectionPredTuple]: + """Post process detections, including NMS""" + # boxes [B, N, 4] + # scores [B, N] + # labels [B, N] + + batch_size = boxes.shape[0] + n_classes = scores.shape[-1] + + device = boxes.device + box_dtype = boxes.dtype + scores_dtype = scores.dtype + + results = [] + for b_id in range(batch_size): + object_labels = [] + object_boxes = [] + object_scores = [] + + for class_index in range(1, n_classes): + probs = scores[b_id, :, class_index] + mask = probs > self.conf_threshold + probs = probs[mask] + if probs.size(0) == 0: + continue + masked_boxes = boxes[b_id, mask, :] + + # keep only top-k indices + num_topk = min(self.top_k, probs.size(0)) + probs, idxs = probs.topk(num_topk) + masked_boxes = masked_boxes[idxs, ...] + + object_boxes.append(masked_boxes) + object_scores.append(probs) + object_labels.append( + torch.full_like( + probs, fill_value=class_index, dtype=torch.int64, device=device + ) + ) + + if len(object_scores) == 0: + output = DetectionPredTuple( + labels=torch.empty(0, device=device, dtype=torch.long), + scores=torch.empty(0, device=device, dtype=scores_dtype), + boxes=torch.empty(0, 4, device=device, dtype=box_dtype), + ) + else: + # concatenate all results + object_scores = torch.cat(object_scores, dim=0) + object_boxes = torch.cat(object_boxes, dim=0) + object_labels = torch.cat(object_labels, dim=0) + + # non-maximum suppression + keep = batched_nms( + object_boxes, object_scores, object_labels, self.nms_threshold + ) + keep = keep[: self.objects_per_image] + + output = DetectionPredTuple( + labels=object_labels[keep], + scores=object_scores[keep], + boxes=object_boxes[keep], + ) + results.append(output) + return results + + def profile_backbone(self, x: Tensor) -> Tuple[Dict[str, Tensor], float, float]: + params, macs = 0.0, 0.0 + enc_end_points, p, m = self.encoder.profile_model(x, is_classification=False) + params += p + macs += m + + end_points = dict() + for idx, os in enumerate(self.output_strides): + if os == 8: + end_points["os_{}".format(os)] = enc_end_points.pop("out_l3") + elif os == 16: + end_points["os_{}".format(os)] = enc_end_points.pop("out_l4") + elif os == 32: + end_points["os_{}".format(os)] = enc_end_points.pop("out_l5") + else: + x = end_points["os_{}".format(self.output_strides[idx - 1])] + x, p, m = module_profile( + module=self.extra_layers["os_{}".format(os)], x=x + ) + end_points["os_{}".format(os)] = x + + params += p + macs += m + + if self.fpn is not None: + end_points, p, m = self.fpn.profile_module(end_points) + params += p + macs += m + + enc_str = ( + logger.text_colors["logs"] + + logger.text_colors["bold"] + + "FPN " + + logger.text_colors["end_color"] + ) + print("{:>45}".format(enc_str)) + print( + "{:<15} \t {:<5}: {:>8.3f} M \t {:<5}: {:>8.3f} M".format( + self.fpn.__class__.__name__, + "Params", + round(p / 1e6, 3), + "MACs", + round(m / 1e6, 3), + ) + ) + logger.singe_dash_line() + return end_points, params, macs + + def profile_model(self, input: Tensor) -> None: + """ + This function computes layer-wise FLOPs and parameters for SSD + + .. note:: + Model profiling is for reference only and may contain errors as it relies heavily on user + to implement the underlying functions accurately. + """ + overall_params, overall_macs = 0.0, 0.0 + input_fvcore = input.clone() + + logger.log("Model statistics for an input of size {}".format(input.size())) + logger.double_dash_line(dashes=65) + print("{:>35} Summary".format(self.__class__.__name__)) + logger.double_dash_line(dashes=65) + + # profile encoder + enc_str = ( + logger.text_colors["logs"] + + logger.text_colors["bold"] + + "Encoder " + + logger.text_colors["end_color"] + ) + print("{:>45}".format(enc_str)) + backbone_end_points, encoder_params, encoder_macs = self.profile_backbone( + x=input + ) + + ssd_head_params = ssd_head_macs = 0.0 + for os, ssd_head in zip(self.output_strides, self.ssd_heads): + _, p, m = module_profile( + module=ssd_head, x=backbone_end_points["os_{}".format(os)] + ) + ssd_head_params += p + ssd_head_macs += m + + overall_params += encoder_params + ssd_head_params + overall_macs += encoder_macs + ssd_head_macs + + ssd_str = ( + logger.text_colors["logs"] + + logger.text_colors["bold"] + + "SSD " + + logger.text_colors["end_color"] + ) + print("{:>45}".format(ssd_str)) + + print( + "{:<15} \t {:<5}: {:>8.3f} M \t {:<5}: {:>8.3f} M".format( + self.__class__.__name__, + "Params", + round(ssd_head_params / 1e6, 3), + "MACs", + round(ssd_head_macs / 1e6, 3), + ) + ) + + logger.double_dash_line(dashes=65) + print("{:<20} = {:>8.3f} M".format("Overall parameters", overall_params / 1e6)) + overall_params_py = sum([p.numel() for p in self.parameters()]) + print( + "{:<20} = {:>8.3f} M".format( + "Overall parameters (sanity check)", overall_params_py / 1e6 + ) + ) + + # Counting Addition and Multiplication as 1 operation + print( + "{:<20} = {:>8.3f} M".format( + "Overall MACs (theoretical)", overall_macs / 1e6 + ) + ) + + # compute flops using FVCore + try: + # compute flops using FVCore also + from fvcore.nn import FlopCountAnalysis + + flop_analyzer = FlopCountAnalysis(self.eval(), input_fvcore) + flop_analyzer.unsupported_ops_warnings(False) + flop_analyzer.uncalled_modules_warnings(False) + flops_fvcore = flop_analyzer.total() + print( + "{:<20} = {:>8.3f} M".format( + "Overall MACs (FVCore)**", flops_fvcore / 1e6 + ) + ) + print( + "\n** Theoretical and FVCore MACs may vary as theoretical MACs do not account " + "for certain operations which may or may not be accounted in FVCore" + ) + except Exception: + pass + + print("Note: Theoretical MACs depends on user-implementation. Be cautious") + + logger.double_dash_line(dashes=65) + + def dummy_input_and_label(self, batch_size: int) -> Dict: + """Create dummy input and labels for CI/CD purposes.""" + img_channels = 3 + height = 320 + width = 320 + n_classes = 80 + + def generate_anchors(height, width): + """Generate anchors **on-the-fly** based on the input resolution.""" + anchors = [] + for output_stride in self.output_strides: + if output_stride == -1: + fm_width = fm_height = 1 + else: + fm_width = int(math.ceil(width / output_stride)) + fm_height = int(math.ceil(height / output_stride)) + fm_anchor = self.anchor_box_generator( + fm_height=fm_height, + fm_width=fm_width, + fm_output_stride=output_stride, + ) + anchors.append(fm_anchor) + anchors = torch.cat(anchors, dim=0) + return anchors + + # GT boxes have the same shape as anchors. So, we use anchors as GT boxes + gt_boxes = generate_anchors(height=height, width=width) + gt_boxes = gt_boxes.unsqueeze(0).expand(batch_size, -1, -1) + + gt_box_labels = torch.randint( + low=0, + high=n_classes, + size=(batch_size, gt_boxes.shape[1]), + dtype=torch.long, + ) + + img_tensor = torch.randn( + batch_size, img_channels, height, width, dtype=torch.float + ) + labels = { + "box_labels": gt_box_labels, + "box_coordinates": gt_boxes, + } + + return {"samples": img_tensor, "targets": labels} diff --git a/Adaptive Frequency Filters/affnet/models/detection/utils/__init__.py b/Adaptive Frequency Filters/affnet/models/detection/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/affnet/models/detection/utils/rcnn_utils.py b/Adaptive Frequency Filters/affnet/models/detection/utils/rcnn_utils.py new file mode 100644 index 0000000..9070f2f --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/detection/utils/rcnn_utils.py @@ -0,0 +1,264 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from typing import Tuple, List + +import torch +from torch import nn, Tensor + +from ....layers import ( + ConvLayer, + LinearLayer, + TransposeConvLayer, + get_normalization_layer, +) +from ....misc.init_utils import initialize_conv_layer, initialize_fc_layer + + +# Below classes are adapted from Torchvision version=0.12 to make the code compatible with previous torch versions. + + +class FastRCNNConvFCHead(nn.Sequential): + def __init__( + self, + opts, + input_size: Tuple[int, int, int], + conv_layers: List[int], + fc_layers: List[int], + *args, + **kwargs, + ): + """ + Args: + input_size (Tuple[int, int, int]): the input size in CHW format. + conv_layers (list): feature dimensions of each Convolution layer + fc_layers (list): feature dimensions of each FCN layer + """ + in_channels, in_height, in_width = input_size + + blocks = [] + previous_channels = in_channels + for current_channels in conv_layers: + blocks.extend( + [ + ConvLayer( + opts, + in_channels=previous_channels, + out_channels=current_channels, + kernel_size=3, + stride=1, + use_norm=False, + use_act=False, + ), + replace_syncbn_with_syncbnfp32(opts, num_features=current_channels), + nn.ReLU(inplace=False), + ] + ) + previous_channels = current_channels + blocks.append(nn.Flatten()) + previous_channels = previous_channels * in_height * in_width + + for current_channels in fc_layers: + blocks.append(LinearLayer(previous_channels, current_channels, bias=True)) + blocks.append(nn.ReLU(inplace=True)) + previous_channels = current_channels + + super().__init__(*blocks) + for layer in self.modules(): + if isinstance(layer, nn.Conv2d): + initialize_conv_layer(module=layer, init_method="kaiming_normal") + elif isinstance(layer, LinearLayer): + initialize_fc_layer(module=layer, init_method="kaiming_uniform") + + +class RPNHead(nn.Module): + """ + Adds a simple RPN Head with classification and regression heads + + Args: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + conv_depth (int, optional): number of convolutions + """ + + def __init__(self, opts, in_channels: int, num_anchors: int, conv_depth=1) -> None: + super().__init__() + convs = [] + for _ in range(conv_depth): + convs.extend( + [ + ConvLayer( + opts, + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=1, + use_norm=False, + use_act=False, + bias=False, + ), + replace_syncbn_with_syncbnfp32(opts, num_features=in_channels), + nn.ReLU(inplace=False), + ] + ) + self.conv = nn.Sequential(*convs) + self.cls_logits = ConvLayer( + opts, + in_channels=in_channels, + out_channels=num_anchors, + kernel_size=1, + stride=1, + use_norm=False, + use_act=False, + bias=True, + ) + self.bbox_pred = ConvLayer( + opts, + in_channels=in_channels, + out_channels=num_anchors * 4, + kernel_size=1, + stride=1, + use_act=False, + use_norm=False, + bias=True, + ) + + for layer in self.modules(): + if isinstance(layer, nn.Conv2d): + initialize_conv_layer(module=layer, init_method="normal", std_val=0.01) + + def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]: + logits = [] + bbox_reg = [] + for feature in x: + t = self.conv(feature) + logits.append(self.cls_logits(t)) + bbox_reg.append(self.bbox_pred(t)) + return logits, bbox_reg + + +class MaskRCNNHeads(nn.Sequential): + def __init__(self, opts, in_channels: int, layers: List, dilation: int): + """ + Args: + in_channels (int): number of input channels + layers (list): feature dimensions of each FCN layer + dilation (int): dilation rate of kernel + norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None + """ + blocks = [] + next_feature = in_channels + for layer_features in layers: + blocks.extend( + [ + ConvLayer( + opts=opts, + in_channels=next_feature, + out_channels=layer_features, + kernel_size=3, + stride=1, + dilation=dilation, + use_norm=False, + use_act=False, + bias=False, + ), + replace_syncbn_with_syncbnfp32( + opts=opts, num_features=layer_features + ), + nn.ReLU(inplace=False), + ] + ) + next_feature = layer_features + + super().__init__(*blocks) + + for layer in self.modules(): + if isinstance(layer, nn.Conv2d): + initialize_conv_layer(module=layer, init_method="kaiming_normal") + + +class MaskRCNNPredictor(nn.Sequential): + def __init__( + self, opts, in_channels: int, dim_reduced: int, num_classes: int + ) -> None: + super().__init__( + *[ + TransposeConvLayer( + opts, + in_channels=in_channels, + out_channels=dim_reduced, + kernel_size=2, + stride=2, + padding=0, + output_padding=0, + use_norm=False, + use_act=False, + bias=False, + groups=1, + ), + replace_syncbn_with_syncbnfp32(opts, num_features=dim_reduced), + nn.ReLU(inplace=False), + ConvLayer( + opts, + in_channels=dim_reduced, + out_channels=num_classes, + kernel_size=1, + stride=1, + bias=True, + use_norm=False, + use_act=False, + ), + ] + ) + + for layer in self.modules(): + if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)): + initialize_conv_layer(module=layer, init_method="kaiming_normal") + + +class FastRCNNPredictor(nn.Module): + """ + Standard classification + bounding box regression layers + for Fast R-CNN. + + Args: + in_channels (int): number of input channels + num_classes (int): number of output classes (including background) + """ + + def __init__(self, in_channels: int, num_classes: int) -> None: + super().__init__() + self.cls_score = LinearLayer(in_channels, num_classes, bias=True) + self.bbox_pred = LinearLayer(in_channels, num_classes * 4, bias=True) + + for layer in self.modules(): + if isinstance(layer, LinearLayer): + initialize_fc_layer(module=layer, init_method="kaiming_uniform") + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + if x.dim() == 4: + torch._assert( + list(x.shape[2:]) == [1, 1], + f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}", + ) + x = x.flatten(start_dim=1) + scores = self.cls_score(x) + bbox_deltas = self.bbox_pred(x) + + return scores, bbox_deltas + + +def replace_syncbn_with_syncbnfp32(opts, num_features: int) -> nn.Module: + # Sync-BN with 0 batch size does not work well with AMP. To avoid that, + # we replace all sync_bn in mask rcnn head with FP32 ones. + norm_layer = getattr(opts, "model.normalization.name", None) + + if norm_layer.find("sync") > -1: + return get_normalization_layer( + opts, num_features=num_features, norm_type="sync_batch_norm_fp32" + ) + else: + return get_normalization_layer(opts=opts, num_features=num_features) diff --git a/Adaptive Frequency Filters/affnet/models/segmentation/__init__.py b/Adaptive Frequency Filters/affnet/models/segmentation/__init__.py new file mode 100644 index 0000000..0544e37 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/segmentation/__init__.py @@ -0,0 +1,149 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import importlib +from utils import logger +import argparse +from utils.download_utils import get_local_path +from utils.ddp_utils import is_master +from utils.common_utils import check_frozen_norm_layer + +from .. import register_tasks, register_task_arguments +from .base_seg import BaseSegmentation +from ...misc.common import load_pretrained_model +from ..classification import build_classification_model + +SEG_MODEL_REGISTRY = {} + + +def register_segmentation_models(name): + def register_model_class(cls): + if name in SEG_MODEL_REGISTRY: + raise ValueError("Cannot register duplicate model ({})".format(name)) + + if not issubclass(cls, BaseSegmentation): + raise ValueError( + "Model ({}: {}) must extend BaseSegmentation".format(name, cls.__name__) + ) + + SEG_MODEL_REGISTRY[name] = cls + return cls + + return register_model_class + + +@register_tasks(name="segmentation") +def build_segmentation_model(opts): + seg_model_name = getattr(opts, "model.segmentation.name", None) + model = None + is_master_node = is_master(opts) + if seg_model_name in SEG_MODEL_REGISTRY: + output_stride = getattr(opts, "model.segmentation.output_stride", None) + encoder = build_classification_model(opts=opts, output_stride=output_stride) + + seg_act_fn = getattr(opts, "model.segmentation.activation.name", None) + if seg_act_fn is not None: + # Override the general activation arguments + gen_act_fn = getattr(opts, "model.activation.name", "relu") + gen_act_inplace = getattr(opts, "model.activation.inplace", False) + gen_act_neg_slope = getattr(opts, "model.activation.neg_slope", 0.1) + + setattr(opts, "model.activation.name", seg_act_fn) + setattr( + opts, + "model.activation.inplace", + getattr(opts, "model.segmentation.activation.inplace", False), + ) + setattr( + opts, + "model.activation.neg_slope", + getattr(opts, "model.segmentation.activation.neg_slope", 0.1), + ) + + model = SEG_MODEL_REGISTRY[seg_model_name](opts, encoder) + + # Reset activation args + setattr(opts, "model.activation.name", gen_act_fn) + setattr(opts, "model.activation.inplace", gen_act_inplace) + setattr(opts, "model.activation.neg_slope", gen_act_neg_slope) + else: + model = SEG_MODEL_REGISTRY[seg_model_name](opts, encoder) + else: + supported_models = list(SEG_MODEL_REGISTRY.keys()) + if len(supported_models) == 0: + supported_models = ["none"] + supp_model_str = "Supported segmentation models are:" + for i, m_name in enumerate(supported_models): + supp_model_str += "\n\t {}: {}".format(i, logger.color_text(m_name)) + logger.error(supp_model_str) + + finetune_task = getattr(opts, "model.segmentation.finetune_pretrained_model", False) + pretrained = getattr(opts, "model.segmentation.pretrained", None) + if finetune_task: + n_pretrained_classes = getattr( + opts, "model.segmentation.n_pretrained_classes", None + ) + n_classes = getattr(opts, "model.segmentation.n_classes", None) + assert n_pretrained_classes is not None + assert n_classes is not None + + # The model structure is the same as pre-trained model now + model.update_classifier(opts, n_classes=n_pretrained_classes) + + # load the weights + if pretrained is not None: + pretrained = get_local_path(opts, path=pretrained) + model = load_pretrained_model(model=model, wt_loc=pretrained, opts=opts) + + # Now, re-initialize the classification layer + model.update_classifier(opts, n_classes=n_classes) + + elif pretrained is not None: + pretrained = get_local_path(opts, path=pretrained) + model = load_pretrained_model(model=model, wt_loc=pretrained, opts=opts) + + freeze_norm_layers = getattr(opts, "model.segmentation.freeze_batch_norm", False) + if freeze_norm_layers: + model.freeze_norm_layers() + frozen_state, count_norm = check_frozen_norm_layer(model) + if count_norm > 0 and frozen_state and is_master_node: + logger.error( + "Something is wrong while freezing normalization layers. Please check" + ) + + if is_master_node: + logger.log("Normalization layers are frozen") + + return model + + +@register_task_arguments(name="segmentation") +def arguments_segmentation(parser: argparse.ArgumentParser): + parser = BaseSegmentation.add_arguments(parser) + + # add segmentation specific arguments + for k, v in SEG_MODEL_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + + from affnet.models.segmentation.heads import arguments_segmentation_head + + parser = arguments_segmentation_head(parser) + + return parser + + +# automatically import the models +models_dir = os.path.dirname(__file__) +for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("affnet.models.segmentation." + model_name) diff --git a/Adaptive Frequency Filters/affnet/models/segmentation/base_seg.py b/Adaptive Frequency Filters/affnet/models/segmentation/base_seg.py new file mode 100644 index 0000000..4c3cb28 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/segmentation/base_seg.py @@ -0,0 +1,184 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +import argparse +from typing import Optional, Tuple, Dict + +from ..classification import BaseEncoder +from ... import parameter_list +from ...layers import norm_layers_tuple +from ...misc.init_utils import initialize_weights + + +class BaseSegmentation(nn.Module): + """Base class for segmentation networks""" + + def __init__(self, opts, encoder: BaseEncoder, *args, **kwargs) -> None: + super().__init__() + self.lr_multiplier = getattr(opts, "model.segmentation.lr_multiplier", 1.0) + assert isinstance( + encoder, BaseEncoder + ), "encoder should be an instance of BaseEncoder" + self.encoder: BaseEncoder = encoder + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + """Add segmentation model specific arguments""" + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--model.segmentation.name", type=str, default=None, help="Model name" + ) + group.add_argument( + "--model.segmentation.n-classes", + type=int, + default=20, + help="Number of classes in the dataset", + ) + group.add_argument( + "--model.segmentation.pretrained", + type=str, + default=None, + help="Path of the pretrained segmentation model. Useful for evaluation", + ) + group.add_argument( + "--model.segmentation.lr-multiplier", + type=float, + default=1.0, + help="Multiply the learning rate in segmentation network (e.g., decoder)", + ) + group.add_argument( + "--model.segmentation.classifier-dropout", + type=float, + default=0.1, + help="Dropout rate in classifier", + ) + group.add_argument( + "--model.segmentation.use-aux-head", + action="store_true", + help="Use auxiliary output", + ) + group.add_argument( + "--model.segmentation.aux-dropout", + default=0.1, + type=float, + help="Dropout in auxiliary branch", + ) + + group.add_argument( + "--model.segmentation.output-stride", + type=int, + default=None, + help="Output stride in classification network", + ) + group.add_argument( + "--model.segmentation.replace-stride-with-dilation", + action="store_true", + help="Replace stride with dilation", + ) + + group.add_argument( + "--model.segmentation.activation.name", + default=None, + type=str, + help="Non-linear function type", + ) + group.add_argument( + "--model.segmentation.activation.inplace", + action="store_true", + help="Inplace non-linear functions", + ) + group.add_argument( + "--model.segmentation.activation.neg-slope", + default=0.1, + type=float, + help="Negative slope in leaky relu", + ) + group.add_argument( + "--model.segmentation.freeze-batch-norm", + action="store_true", + help="Freeze batch norm layers", + ) + + group.add_argument( + "--model.segmentation.use-level5-exp", + action="store_true", + help="Use output of Level 5 expansion layer in base feature extractor", + ) + + group.add_argument( + "--model.segmentation.finetune-pretrained-model", + action="store_true", + help="Finetune a pretrained model", + ) + group.add_argument( + "--model.segmentation.n-pretrained-classes", + type=int, + default=None, + help="Number of pre-trained classes", + ) + return parser + + @staticmethod + def reset_layer_parameters(layer, opts) -> None: + """Reset weights of a given layer""" + initialize_weights(opts=opts, modules=layer.modules()) + + def get_trainable_parameters( + self, + weight_decay: Optional[float] = 0.0, + no_decay_bn_filter_bias: Optional[bool] = False, + *args, + **kwargs + ): + param_list = parameter_list( + named_parameters=self.named_parameters, + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + *args, + **kwargs + ) + return param_list, [1.0] * len(param_list) + + def profile_model(self, input: Tensor) -> Optional[Tuple[Tensor, float, float]]: + """ + Child classes must implement this function to compute FLOPs and parameters + """ + raise NotImplementedError + + def freeze_norm_layers(self) -> None: + for m in self.modules(): + if isinstance(m, norm_layers_tuple): + m.eval() + m.weight.requires_grad = False + m.bias.requires_grad = False + m.training = False + + def dummy_input_and_label(self, batch_size: int) -> Dict: + """Create dummy input and labels for CI/CD purposes. Child classes must override it + if functionality is different. + """ + img_channels = 3 + height = 224 + width = 224 + n_classes = 10 + img_tensor = torch.randn( + batch_size, img_channels, height, width, dtype=torch.float + ) + label_tensor = torch.randint( + low=0, high=n_classes, size=(batch_size, height, width) + ).long() + return {"samples": img_tensor, "targets": label_tensor} + + def update_classifier(self, opts, n_classes: int) -> None: + """ + This function updates the classification layer in a model. Useful for finetuning purposes. + """ + raise NotImplementedError diff --git a/Adaptive Frequency Filters/affnet/models/segmentation/enc_dec.py b/Adaptive Frequency Filters/affnet/models/segmentation/enc_dec.py new file mode 100644 index 0000000..7a6357f --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/segmentation/enc_dec.py @@ -0,0 +1,198 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import Tensor +from typing import Union, Dict, Tuple, Optional +import argparse + +from utils import logger + +from . import BaseSegmentation, register_segmentation_models +from .heads import build_segmentation_head +from ..classification import BaseEncoder + + +@register_segmentation_models(name="encoder_decoder") +class SegEncoderDecoder(BaseSegmentation): + """ + This class defines a encoder-decoder architecture for the task of semantic segmentation. Different segmentation + heads (e.g., PSPNet and DeepLabv3) can be used + + Args: + opts: command-line arguments + encoder (BaseEncoder): Backbone network (e.g., MobileViT or ResNet) + """ + + def __init__(self, opts, encoder: BaseEncoder, *args, **kwargs) -> None: + super().__init__(opts=opts, encoder=encoder) + + # delete layers that are not required in segmentation network + self.encoder.classifier = None + use_l5_exp = getattr(opts, "model.segmentation.use_level5_exp", False) + if not use_l5_exp: + self.encoder.conv_1x1_exp = None + + self.seg_head = build_segmentation_head( + opts=opts, enc_conf=self.encoder.model_conf_dict, use_l5_exp=use_l5_exp + ) + self.use_l5_exp = use_l5_exp + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + def get_trainable_parameters( + self, + weight_decay: Optional[float] = 0.0, + no_decay_bn_filter_bias: Optional[bool] = False, + *args, + **kwargs + ): + """This function separates the parameters for backbone and segmentation head, so that + different learning rates can be used for backbone and segmentation head + """ + encoder_params, enc_lr_mult = self.encoder.get_trainable_parameters( + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + module_name="encoder.", + *args, + **kwargs + ) + decoder_params, dec_lr_mult = self.seg_head.get_trainable_parameters( + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + module_name="seg_head.", + *args, + **kwargs + ) + + total_params = sum([p.numel() for p in self.parameters()]) + encoder_params_count = sum([p.numel() for p in self.encoder.parameters()]) + decoder_params_count = sum([p.numel() for p in self.seg_head.parameters()]) + + assert total_params == encoder_params_count + decoder_params_count, ( + "Total network parameters are not equal to " + "the sum of encoder and decoder. " + "{} != {} + {}".format( + total_params, encoder_params_count, decoder_params_count + ) + ) + + return encoder_params + decoder_params, enc_lr_mult + dec_lr_mult + + def forward( + self, x: Tensor, *args, **kwargs + ) -> Union[Tuple[Tensor, Tensor], Tensor, Dict]: + enc_end_points: Dict = self.encoder.extract_end_points_all( + x, use_l5=True, use_l5_exp=self.use_l5_exp + ) + + if "augmented_tensor" in enc_end_points: + output_dict = { + "augmented_tensor": enc_end_points.pop("augmented_tensor"), + "segmentation_output": self.seg_head( + enc_out=enc_end_points, *args, **kwargs + ), + } + return output_dict + else: + return self.seg_head(enc_out=enc_end_points, *args, **kwargs) + + def update_classifier(self, opts, n_classes: int) -> None: + """ + This function updates the classification layer in a model. Useful for finetuning purposes. + """ + if hasattr(self.seg_head, "update_classifier"): + self.seg_head.update_classifier(opts, n_classes) + + def profile_model(self, input: Tensor) -> None: + """ + This function computes layer-wise FLOPs and parameters for segmentation network + + .. note:: + Model profiling is for reference only and may contain errors as it relies heavily on user + to implement the underlying functions accurately. + """ + + overall_params, overall_macs = 0.0, 0.0 + input_fvcore = input.clone() + + logger.log("Model statistics for an input of size {}".format(input.size())) + logger.double_dash_line(dashes=65) + print("{:>35} Summary".format(self.__class__.__name__)) + logger.double_dash_line(dashes=65) + + # profile encoder + enc_str = ( + logger.text_colors["logs"] + + logger.text_colors["bold"] + + "Encoder " + + logger.text_colors["end_color"] + ) + print("{:>45}".format(enc_str)) + enc_end_points, encoder_params, encoder_macs = self.encoder.profile_model( + input, is_classification=False + ) + overall_params += encoder_params + overall_macs += encoder_macs + + # profile decoder + dec_str = ( + logger.text_colors["logs"] + + logger.text_colors["bold"] + + "Decoder " + + logger.text_colors["end_color"] + ) + print("{:>45}".format(dec_str)) + + out, decoder_params, decoder_macs = self.seg_head.profile_module(enc_end_points) + overall_params += decoder_params + overall_macs += decoder_macs + + logger.double_dash_line(dashes=65) + print("{:<20} = {:>8.3f} M".format("Overall parameters", overall_params / 1e6)) + overall_params_py = sum([p.numel() for p in self.parameters()]) + print( + "{:<20} = {:>8.3f} M".format( + "Overall parameters (sanity check)", overall_params_py / 1e6 + ) + ) + + # Counting Addition and Multiplication as 1 operation + print( + "{:<20} = {:>8.3f} M".format( + "Overall MACs (theoretical)", overall_macs / 1e6 + ) + ) + + # compute flops using FVCore + try: + # compute flops using FVCore also + from fvcore.nn import FlopCountAnalysis + + flop_analyzer = FlopCountAnalysis(self.eval(), input_fvcore) + flop_analyzer.unsupported_ops_warnings(False) + flop_analyzer.uncalled_modules_warnings(False) + flops_fvcore = flop_analyzer.total() + print( + "{:<20} = {:>8.3f} M".format( + "Overall MACs (FVCore)**", flops_fvcore / 1e6 + ) + ) + print( + "\n** Theoretical and FVCore MACs may vary as theoretical MACs do not account " + "for certain operations which may or may not be accounted in FVCore" + ) + except ModuleNotFoundError as mnfe: + logger.warning( + "Please install fvcore to profile {} model".format( + self.__class__.__name__ + ) + ) + except Exception: + pass + + logger.double_dash_line(dashes=65) diff --git a/Adaptive Frequency Filters/affnet/models/segmentation/heads/__init__.py b/Adaptive Frequency Filters/affnet/models/segmentation/heads/__init__.py new file mode 100644 index 0000000..f730c5c --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/segmentation/heads/__init__.py @@ -0,0 +1,72 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import importlib +from utils import logger +from typing import Dict +import argparse + +from .base_seg_head import BaseSegHead + +SEG_HEAD_REGISTRY = {} + + +def register_segmentation_head(name): + def register_model_class(cls): + if name in SEG_HEAD_REGISTRY: + raise ValueError("Cannot register duplicate model ({})".format(name)) + + if not issubclass(cls, BaseSegHead): + raise ValueError( + "Model ({}: {}) must extend BaseSegHead".format(name, cls.__name__) + ) + + SEG_HEAD_REGISTRY[name] = cls + return cls + + return register_model_class + + +def build_segmentation_head(opts, enc_conf: Dict, use_l5_exp: bool = False): + seg_model_name = getattr(opts, "model.segmentation.seg_head", "lr_aspp") + seg_head = None + if seg_model_name in SEG_HEAD_REGISTRY: + seg_head = SEG_HEAD_REGISTRY[seg_model_name]( + opts=opts, enc_conf=enc_conf, use_l5_exp=use_l5_exp + ) + else: + supported_heads = list(SEG_HEAD_REGISTRY.keys()) + supp_model_str = "Supported segmentation heads are:" + for i, m_name in enumerate(supported_heads): + supp_model_str += "\n\t {}: {}".format(i, logger.color_text(m_name)) + logger.error(supp_model_str) + + return seg_head + + +def arguments_segmentation_head(parser: argparse.ArgumentParser): + # add segmentation head specific arguments + parser = BaseSegHead.add_arguments(parser=parser) + for k, v in SEG_HEAD_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + + return parser + + +# automatically import the models +models_dir = os.path.dirname(__file__) +for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module( + "affnet.models.segmentation.heads." + model_name + ) diff --git a/Adaptive Frequency Filters/affnet/models/segmentation/heads/base_seg_head.py b/Adaptive Frequency Filters/affnet/models/segmentation/heads/base_seg_head.py new file mode 100644 index 0000000..55165ab --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/segmentation/heads/base_seg_head.py @@ -0,0 +1,175 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from typing import Optional, Dict, Tuple +import argparse + +from utils import logger + +from ....misc.common import parameter_list +from ....misc.init_utils import initialize_weights +from ....layers import ConvLayer, Dropout2d, UpSample + + +class BaseSegHead(nn.Module): + """ + Base class for segmentation heads + """ + + def __init__(self, opts, enc_conf: dict, use_l5_exp: Optional[bool] = False): + enc_ch_l5_exp_out = _check_out_channels(enc_conf, "exp_before_cls") + enc_ch_l5_out = _check_out_channels(enc_conf, "layer5") + enc_ch_l4_out = _check_out_channels(enc_conf, "layer4") + enc_ch_l3_out = _check_out_channels(enc_conf, "layer3") + enc_ch_l2_out = _check_out_channels(enc_conf, "layer2") + enc_ch_l1_out = _check_out_channels(enc_conf, "layer1") + + super().__init__() + + self.use_l5_exp = use_l5_exp + self.enc_l5_exp_channels = enc_ch_l5_exp_out + self.enc_l5_channels = enc_ch_l5_out + self.enc_l4_channels = enc_ch_l4_out + self.enc_l3_channels = enc_ch_l3_out + self.enc_l2_channels = enc_ch_l2_out + self.enc_l1_channels = enc_ch_l1_out + + self.n_seg_classes = getattr(opts, "model.segmentation.n_classes", 20) + self.lr_multiplier = getattr(opts, "model.segmentation.lr_multiplier", 1.0) + self.classifier_dropout = getattr( + opts, "model.segmentation.classifier_dropout", 0.1 + ) + self.output_stride = getattr(opts, "model.segmentation.output_stride", 16) + + self.aux_head = None + if getattr(opts, "model.segmentation.use_aux_head", False): + drop_aux = getattr(opts, "model.segmentation.aux_dropout", 0.1) + inner_channels = max(int(self.enc_l4_channels // 4), 128) + self.aux_head = nn.Sequential( + ConvLayer( + opts=opts, + in_channels=self.enc_l4_channels, + out_channels=inner_channels, + kernel_size=3, + stride=1, + use_norm=True, + use_act=True, + bias=False, + groups=1, + ), + Dropout2d(drop_aux), + ConvLayer( + opts=opts, + in_channels=inner_channels, + out_channels=self.n_seg_classes, + kernel_size=1, + stride=1, + use_norm=False, + use_act=False, + bias=True, + groups=1, + ), + ) + + self.upsample_seg_out = None + if self.output_stride != 1.0: + self.upsample_seg_out = UpSample( + scale_factor=self.output_stride, mode="bilinear", align_corners=True + ) + + def forward_aux_head(self, enc_out: Dict) -> Tensor: + aux_out = self.aux_head(enc_out["out_l4"]) + return aux_out + + def forward_seg_head(self, enc_out: Dict) -> Tensor: + raise NotImplementedError + + def forward(self, enc_out: Dict, *args, **kwargs) -> Tensor or Tuple[Tensor]: + out = self.forward_seg_head(enc_out=enc_out) + + if self.upsample_seg_out is not None: + # resize the mask based on given size + mask_size = kwargs.get("orig_size", None) + if mask_size is not None: + self.upsample_seg_out.scale_factor = None + self.upsample_seg_out.size = mask_size + + out = self.upsample_seg_out(out) + + if self.aux_head is not None and self.training: + aux_out = self.forward_aux_head(enc_out=enc_out) + return out, aux_out + return out + + def reset_head_parameters(self, opts) -> None: + # weight initialization + initialize_weights(opts=opts, modules=self.modules()) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add segmentation head specific arguments""" + group = parser.add_argument_group( + title="Segmentation head arguments", + description="Segmentation head arguments", + ) + group.add_argument( + "--model.segmentation.seg-head", + type=str, + default=None, + help="Segmentation head", + ) + + return parser + + def profile_module(self, x: Tensor) -> Tuple[Tensor, float, float]: + """ + Child classes must implement this function to compute FLOPs and parameters + """ + raise NotImplementedError + + def get_trainable_parameters( + self, + weight_decay: float = 0.0, + no_decay_bn_filter_bias: bool = False, + *args, + **kwargs + ): + param_list = parameter_list( + named_parameters=self.named_parameters, + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + *args, + **kwargs + ) + return param_list, [self.lr_multiplier] * len(param_list) + + def update_classifier(self, opts, n_classes: int) -> None: + """ + This function updates the classification layer in a model. Useful for finetuning purposes. + """ + raise NotImplementedError + + +def _check_out_channels(config: dict, layer_name: str) -> int: + enc_ch_l: dict = config.get(layer_name, None) + if enc_ch_l is None or not enc_ch_l: + logger.error( + "Encoder does not define input-output mapping for {}: Got: {}".format( + layer_name, config + ) + ) + + enc_ch_l_out = enc_ch_l.get("out", None) + if enc_ch_l_out is None or not enc_ch_l_out: + logger.error( + "Output channels are not defined in {} of the encoder. Got: {}".format( + layer_name, enc_ch_l + ) + ) + + return enc_ch_l_out diff --git a/Adaptive Frequency Filters/affnet/models/segmentation/heads/deeplabv3.py b/Adaptive Frequency Filters/affnet/models/segmentation/heads/deeplabv3.py new file mode 100644 index 0000000..3b42e0c --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/segmentation/heads/deeplabv3.py @@ -0,0 +1,156 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +import argparse +from typing import Optional, Dict, Tuple + +from .base_seg_head import BaseSegHead +from . import register_segmentation_head +from ....layers import ConvLayer +from ....modules import ASPP +from ....misc.profiler import module_profile +from ....misc.init_utils import initialize_weights + + +@register_segmentation_head(name="deeplabv3") +class DeeplabV3(BaseSegHead): + """ + This class defines the segmentation head in `DeepLabv3 architecture `_ + Args: + opts: command-line arguments + enc_conf (Dict): Encoder input-output configuration at each spatial level + use_l5_exp (Optional[bool]): Use features from expansion layer in Level5 in the encoder + """ + + def __init__( + self, opts, enc_conf: Dict, use_l5_exp: Optional[bool] = False, *args, **kwargs + ) -> None: + atrous_rates = getattr( + opts, "model.segmentation.deeplabv3.aspp_rates", (6, 12, 18) + ) + out_channels = getattr( + opts, "model.segmentation.deeplabv3.aspp_out_channels", 256 + ) + is_sep_conv = getattr(opts, "model.segmentation.deeplabv3.aspp_sep_conv", False) + dropout = getattr(opts, "model.segmentation.deeplabv3.aspp_dropout", 0.1) + + super().__init__(opts=opts, enc_conf=enc_conf, use_l5_exp=use_l5_exp) + + self.aspp = nn.Sequential() + aspp_in_channels = ( + self.enc_l5_channels if not self.use_l5_exp else self.enc_l5_exp_channels + ) + self.aspp.add_module( + name="aspp_layer", + module=ASPP( + opts=opts, + in_channels=aspp_in_channels, + out_channels=out_channels, + atrous_rates=atrous_rates, + is_sep_conv=is_sep_conv, + dropout=dropout, + ), + ) + + self.classifier = ConvLayer( + opts=opts, + in_channels=out_channels, + out_channels=self.n_seg_classes, + kernel_size=1, + stride=1, + use_norm=False, + use_act=False, + bias=True, + ) + + self.reset_head_parameters(opts=opts) + + def update_classifier(self, opts, n_classes: int) -> None: + """ + This function updates the classification layer in a model. Useful for finetuning purposes. + """ + in_channels = self.classifier.in_channels + conv_layer = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=n_classes, + kernel_size=1, + stride=1, + use_norm=False, + use_act=False, + bias=True, + ) + initialize_weights(opts, modules=conv_layer) + self.classifier = conv_layer + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """DeepLabv3 specific arguments""" + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--model.segmentation.deeplabv3.aspp-rates", + type=tuple, + default=(6, 12, 18), + help="Atrous rates in DeepLabV3+ model", + ) + group.add_argument( + "--model.segmentation.deeplabv3.aspp-out-channels", + type=int, + default=256, + help="Output channels of ASPP module", + ) + group.add_argument( + "--model.segmentation.deeplabv3.aspp-sep-conv", + action="store_true", + help="Separable conv in ASPP module", + ) + group.add_argument( + "--model.segmentation.deeplabv3.aspp-dropout", + type=float, + default=0.1, + help="Dropout in ASPP module", + ) + return parser + + def forward_seg_head(self, enc_out: Dict) -> Tensor: + # low resolution features + x = enc_out["out_l5_exp"] if self.use_l5_exp else enc_out["out_l5"] + # ASPP featues + x = self.aspp(x) + # classify + x = self.classifier(x) + return x + + def profile_module(self, enc_out: Dict) -> Tuple[Tensor, float, float]: + # Note: Model profiling is for reference only and may contain errors. + # It relies heavily on the user to implement the underlying functions accurately. + + params, macs = 0.0, 0.0 + + if self.use_l5_exp: + x, p, m = module_profile(module=self.aspp, x=enc_out["out_l5_exp"]) + else: + x, p, m = module_profile(module=self.aspp, x=enc_out["out_l5"]) + params += p + macs += m + + out, p, m = module_profile(module=self.classifier, x=x) + params += p + macs += m + + print( + "{:<15} \t {:<5}: {:>8.3f} M \t {:<5}: {:>8.3f} M".format( + self.__class__.__name__, + "Params", + round(params / 1e6, 3), + "MACs", + round(macs / 1e6, 3), + ) + ) + return out, params, macs diff --git a/Adaptive Frequency Filters/affnet/models/segmentation/heads/pspnet.py b/Adaptive Frequency Filters/affnet/models/segmentation/heads/pspnet.py new file mode 100644 index 0000000..d0d8899 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/models/segmentation/heads/pspnet.py @@ -0,0 +1,147 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +import argparse +from typing import Optional, Dict, Tuple + +from .base_seg_head import BaseSegHead +from . import register_segmentation_head +from ....layers import ConvLayer +from ....modules import PSP +from ....misc.profiler import module_profile +from ....misc.init_utils import initialize_weights + + +@register_segmentation_head(name="pspnet") +class PSPNet(BaseSegHead): + """ + This class defines the segmentation head in `PSPNet architecture `_ + Args: + opts: command-line arguments + enc_conf (Dict): Encoder input-output configuration at each spatial level + use_l5_exp (Optional[bool]): Use features from expansion layer in Level5 in the encoder + """ + + def __init__( + self, opts, enc_conf: dict, use_l5_exp: Optional[bool] = False, *args, **kwargs + ) -> None: + psp_out_channels = getattr( + opts, "model.segmentation.pspnet.psp_out_channels", 512 + ) + psp_pool_sizes = getattr( + opts, "model.segmentation.pspnet.psp_pool_sizes", [1, 2, 3, 6] + ) + psp_dropout = getattr(opts, "model.segmentation.pspnet.psp_dropout", 0.1) + + super().__init__(opts=opts, enc_conf=enc_conf, use_l5_exp=use_l5_exp) + + psp_in_channels = ( + self.enc_l5_channels if not self.use_l5_exp else self.enc_l5_exp_channels + ) + self.psp_layer = PSP( + opts=opts, + in_channels=psp_in_channels, + out_channels=psp_out_channels, + pool_sizes=psp_pool_sizes, + dropout=psp_dropout, + ) + self.classifier = ConvLayer( + opts=opts, + in_channels=psp_out_channels, + out_channels=self.n_seg_classes, + kernel_size=1, + stride=1, + use_norm=False, + use_act=False, + bias=True, + ) + self.reset_head_parameters(opts=opts) + + def update_classifier(self, opts, n_classes: int) -> None: + """ + This function updates the classification layer in a model. Useful for finetuning purposes. + """ + in_channels = self.classifier.in_channels + conv_layer = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=n_classes, + kernel_size=1, + stride=1, + use_norm=False, + use_act=False, + bias=True, + ) + + initialize_weights(opts, modules=conv_layer) + self.classifier = conv_layer + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--model.segmentation.pspnet.psp-pool-sizes", + type=int, + nargs="+", + default=[1, 2, 3, 6], + help="Pool sizes in the PSPNet module", + ) + group.add_argument( + "--model.segmentation.pspnet.psp-out-channels", + type=int, + default=512, + help="Output channels of PSPNet module", + ) + group.add_argument( + "--model.segmentation.pspnet.psp-dropout", + type=float, + default=0.1, + help="Dropout in the PSPNet module", + ) + return parser + + def forward_seg_head(self, enc_out: Dict) -> Tensor: + # low resolution features + x = enc_out["out_l5_exp"] if self.use_l5_exp else enc_out["out_l5"] + + # Apply PSP layer + x = self.psp_layer(x) + + out = self.classifier(x) + + return out + + def profile_module(self, enc_out: Dict) -> Tuple[Tensor, float, float]: + # Note: Model profiling is for reference only and may contain errors. + # It relies heavily on the user to implement the underlying functions accurately. + + params, macs = 0.0, 0.0 + + if self.use_l5_exp: + x, p, m = module_profile(module=self.psp_layer, x=enc_out["out_l5_exp"]) + else: + x, p, m = module_profile(module=self.psp_layer, x=enc_out["out_l5"]) + params += p + macs += m + + out, p, m = module_profile(module=self.classifier, x=x) + params += p + macs += m + + print( + "{:<15} \t {:<5}: {:>8.3f} M \t {:<5}: {:>8.3f} M".format( + self.__class__.__name__, + "Params", + round(params / 1e6, 3), + "MACs", + round(macs / 1e6, 3), + ) + ) + return out, params, macs diff --git a/Adaptive Frequency Filters/affnet/modules/__init__.py b/Adaptive Frequency Filters/affnet/modules/__init__.py new file mode 100644 index 0000000..31f127c --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/__init__.py @@ -0,0 +1,27 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from .base_module import BaseModule +from .squeeze_excitation import SqueezeExcitation +from .mobilenetv2 import InvertedResidual, InvertedResidualSE +from .aspp_block import ASPP +from .transformer import TransformerEncoder +from .pspnet_module import PSP +from .feature_pyramid import FeaturePyramidNetwork +from .ssd_heads import SSDHead, SSDInstanceHead + + +__all__ = [ + "InvertedResidual", + "InvertedResidualSE", + "ASPP", + "TransformerEncoder", + "SqueezeExcitation", + "PSP", + "FeaturePyramidNetwork", + "SSDHead", + "SSDInstanceHead", +] diff --git a/Adaptive Frequency Filters/affnet/modules/aff_block.py b/Adaptive Frequency Filters/affnet/modules/aff_block.py new file mode 100644 index 0000000..96bb1de --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/aff_block.py @@ -0,0 +1,541 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- +import einops +import numpy as np +from torch import nn, Tensor +import math +import torch +from torch.nn import functional as F +from typing import Optional, Dict, Tuple, Union, Sequence + +from . import InvertedResidual +from .transformer import TransformerEncoder, LinearAttnFFN +from .base_module import BaseModule +from ..misc.profiler import module_profile +from ..layers import ConvLayer, get_normalization_layer, get_activation_fn +import typing +from typing import Any, List +from einops.layers.torch import Rearrange +import math +import torch +import torch.fft +import torch.nn as nn +import torch.nn.functional as F +import time + + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob, 3):0.3f}' + + +class AFNO2D_channelfirst(nn.Module): + """ + hidden_size: channel dimension size + num_blocks: how many blocks to use in the block diagonal weight matrices (higher => less complexity but less parameters) + sparsity_threshold: lambda for softshrink + hard_thresholding_fraction: how many frequencies you want to completely mask out (lower => hard_thresholding_fraction^2 less FLOPs) + input shape [B N C] + """ + + def __init__(self, opts, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, + hidden_size_factor=1): + super().__init__() + assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" + + self.hidden_size = hidden_size + self.sparsity_threshold = getattr(opts, "model.activation.sparsity_threshold", 0.01) + self.num_blocks = num_blocks + self.block_size = self.hidden_size // self.num_blocks + self.hard_thresholding_fraction = hard_thresholding_fraction + self.hidden_size_factor = hidden_size_factor + self.scale = 0.02 + + self.w1 = nn.Parameter( + self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor)) + self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor)) + self.w2 = nn.Parameter( + self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size)) + self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) + # self.norm_layer1 = get_normalization_layer(opts=opts, num_features=out_channels) + self.act = self.build_act_layer(opts=opts) + self.act2 = self.build_act_layer(opts=opts) + + @staticmethod + def build_act_layer(opts) -> nn.Module: + act_type = getattr(opts, "model.activation.name", "relu") + neg_slope = getattr(opts, "model.activation.neg_slope", 0.1) + inplace = getattr(opts, "model.activation.inplace", False) + act_layer = get_activation_fn( + act_type=act_type, + inplace=inplace, + negative_slope=neg_slope, + num_parameters=1, + ) + return act_layer + + @torch.cuda.amp.autocast(enabled=False) + def forward(self, x, spatial_size=None): + bias = x + + dtype = x.dtype + x = x.float() + B, C, H, W = x.shape + # x = self.fu(x) + + x = torch.fft.rfft2(x, dim=(2, 3), norm="ortho") + origin_ffted = x + x = x.reshape(B, self.num_blocks, self.block_size, x.shape[2], x.shape[3]) + + + o1_real = self.act( + torch.einsum('bkihw,kio->bkohw', x.real, self.w1[0]) - \ + torch.einsum('bkihw,kio->bkohw', x.imag, self.w1[1]) + \ + self.b1[0, :, :, None, None] + ) + + o1_imag = self.act2( + torch.einsum('bkihw,kio->bkohw', x.imag, self.w1[0]) + \ + torch.einsum('bkihw,kio->bkohw', x.real, self.w1[1]) + \ + self.b1[1, :, :, None, None] + ) + + o2_real = ( + torch.einsum('bkihw,kio->bkohw', o1_real, self.w2[0]) - \ + torch.einsum('bkihw,kio->bkohw', o1_imag, self.w2[1]) + \ + self.b2[0, :, :, None, None] + ) + + o2_imag = ( + torch.einsum('bkihw,kio->bkohw', o1_imag, self.w2[0]) + \ + torch.einsum('bkihw,kio->bkohw', o1_real, self.w2[1]) + \ + self.b2[1, :, :, None, None] + ) + + x = torch.stack([o2_real, o2_imag], dim=-1) + x = F.softshrink(x, lambd=self.sparsity_threshold) + x = torch.view_as_complex(x) + x = x.reshape(B, C, x.shape[3], x.shape[4]) + + x = x * origin_ffted + x = torch.fft.irfft2(x, s=(H, W), dim=(2, 3), norm="ortho") + x = x.type(dtype) + + return x + bias + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + # TODO: to edit it + b_sz, c, h, w = input.shape + seq_len = h * w + + # FFT iFFT + p_ff, m_ff = 0, 5 * b_sz * seq_len * int(math.log(seq_len)) * c + # others + # params = macs = sum([p.numel() for p in self.parameters()]) + params = macs = self.hidden_size * self.hidden_size_factor * self.hidden_size * 2 * 2 // self.num_blocks + # // 2 min n become half after fft + macs = macs * b_sz * seq_len + + # return input, params, macs + return input, params, macs + m_ff + + +def remove_edge(img: np.ndarray): + # // remove the edge of a numpy image + return img[1:-1, 1:-1] + +def save_feature(feature): + import time + import matplotlib.pyplot as plt + import os + now = time.time() + feature = feature.detach() + os.makedirs('visual_example', exist_ok=True) + for i in range(feature.shape[1]): + feature_channel = feature[0, i] + fig, ax = plt.subplots() + img_channel = ax.imshow(remove_edge(feature_channel.cpu().numpy()), cmap='gray') + plt.savefig('visual_example/{now}_channel_{i}_feature.png'.format(now=str(now), i=i)) + for i in range(8): + feature_group = torch.mean(feature[0, i * 8:(i + 1) * 8], dim=1) + fig, ax = plt.subplots() + img_group = ax.imshow(remove_edge(feature_group.cpu().numpy()), cmap='gray') + plt.savefig('visual_example/{now}_group_{i}_feature.png'.format(now=str(now), i=i)) + +def save_kernel(origin_ffted, H, W): + import time + import matplotlib.pyplot as plt + import os + now = time.time() + origin_ffted = origin_ffted.detach() + kernel = torch.fft.irfft2(origin_ffted, s=(H, W), dim=(2, 3), norm="ortho") + group_channels = kernel.shape[1] // 8 + os.makedirs('visual_example', exist_ok=True) + for i in range(kernel.shape[1]): + kernel_channel = kernel[0, i] + fig, ax = plt.subplots() + img_channel = ax.imshow(remove_edge(kernel_channel.cpu().numpy()), cmap='gray') + plt.savefig('visual_example/{now}_channel_{i}_kernel.png'.format(now=str(now), i=i)) + for i in range(8): + kernel_group = torch.mean(kernel[0, i*group_channels: (i+1)*group_channels], dim=0) + fig, ax = plt.subplots() + img_group = ax.imshow(remove_edge(kernel_group.cpu().numpy()), cmap='gray') + plt.savefig('visual_example/{now}_group_{i}_kernel.png'.format(now=str(now), i=i)) + kernel_mean = torch.mean(kernel[0], dim=0) + fig, ax = plt.subplots() + img_mean = ax.imshow(remove_edge(kernel_mean.cpu().numpy()), cmap='gray') + plt.savefig('visual_example/{now}_all_kernel.png'.format(now=str(now))) + + abs = origin_ffted.abs() + abs_group_channels = abs.shape[1] // 8 + os.makedirs('visual_mask_example', exist_ok=True) + for i in range(abs.shape[1]): + abs_channel = abs[0, i] + fig, ax = plt.subplots() + abs_channel = ax.imshow(abs_channel.cpu().numpy(), cmap='gray') + plt.savefig('visual_mask_example/{now}_channel_{i}_abs.png'.format(now=str(now), i=i)) + for i in range(8): + abs_group = torch.mean(abs[0, i*abs_group_channels: (i+1)*abs_group_channels], dim=0) + fig, ax = plt.subplots() + img_group = ax.imshow(abs_group.cpu().numpy(), cmap='gray') + plt.savefig('visual_mask_example/{now}_group_{i}_abs.png'.format(now=str(now), i=i)) + abs_mean = torch.mean(abs[0], dim=0) + fig, ax = plt.subplots() + img_mean = ax.imshow(abs_mean.cpu().numpy(), cmap='gray') + plt.savefig('visual_mask_example/{now}_all_abs.png'.format(now=str(now))) + + real = origin_ffted.real + real_group_channels = real.shape[1] // 8 + os.makedirs('visual_mask_example', exist_ok=True) + for i in range(real.shape[1]): + real_channel = real[0, i] + fig, ax = plt.subplots() + real_channel = ax.imshow(real_channel.cpu().numpy(), cmap='gray') + plt.savefig('visual_mask_example/{now}_channel_{i}_real.png'.format(now=str(now), i=i)) + for i in range(8): + real_group = torch.mean(real[0, i*real_group_channels: (i+1)*real_group_channels], dim=0) + fig, ax = plt.subplots() + img_group = ax.imshow(real_group.cpu().numpy(), cmap='gray') + plt.savefig('visual_mask_example/{now}_group_{i}_mask.png'.format(now=str(now), i=i)) + real_mean = torch.mean(real[0], dim=0) + fig, ax = plt.subplots() + img_mean = ax.imshow(real_mean.cpu().numpy(), cmap='gray') + plt.savefig('visual_mask_example/{now}_all_real.png'.format(now=str(now))) + + imag = origin_ffted.imag + imag_group_channels = imag.shape[1] // 8 + os.makedirs('visual_mask_example', exist_ok=True) + for i in range(8): + imag_group = torch.mean(imag[0, i*imag_group_channels: (i+1)*imag_group_channels], dim=0) + fig, ax = plt.subplots() + img_group = ax.imshow(imag_group.cpu().numpy(), cmap='gray') + plt.savefig('visual_mask_example/{now}_group_{i}_imag.png'.format(now=str(now), i=i)) + imag_mean = torch.mean(imag[0], dim=0) + fig, ax = plt.subplots() + img_mean = ax.imshow(imag_mean.cpu().numpy(), cmap='gray') + plt.savefig('visual_mask_example/{now}_all_imag.png'.format(now=str(now))) + + + +class Block(nn.Module): + def __init__(self, opts, dim, hidden_size, num_blocks, double_skip, mlp_ratio=4., drop_path=0., attn_norm_layer='sync_batch_norm', enable_coreml_compatible_fn=False): + # input shape [B C H W] + super().__init__() + self.norm1 = get_normalization_layer( + opts=opts, norm_type=attn_norm_layer, num_features=dim + ) + self.filter = AFNO2D_channelfirst(opts=opts, hidden_size=hidden_size, num_blocks=num_blocks, sparsity_threshold=0.01, + hard_thresholding_fraction=1, hidden_size_factor=1) if not enable_coreml_compatible_fn else \ + AFNO2D_channelfirst_coreml(opts=opts, hidden_size=hidden_size, num_blocks=num_blocks, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = get_normalization_layer( + opts=opts, norm_type=attn_norm_layer, num_features=dim + ) + self.mlp = InvertedResidual( + opts=opts, + in_channels=dim, + out_channels=dim, + stride=1, + expand_ratio=mlp_ratio, + ) + self.double_skip = double_skip + + def forward(self, x): + residual = x + x = self.norm1(x) + # x = self.filter(x) + x = self.mlp(x) + + if self.double_skip: + x = x + residual + residual = x + + x = self.norm2(x) + # x = self.mlp(x) + x = self.filter(x) + x = self.drop_path(x) + x = x + residual + return x + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + b_sz, c, h, w = input.shape + seq_len = h * w + + out, p_ffn, m_ffn = module_profile(module=self.mlp, x=input) + # m_ffn = m_ffn * b_sz * seq_len + + out, p_mha, m_mha = module_profile(module=self.filter, x=out) + + + macs = m_mha + m_ffn + params = p_mha + p_ffn + + return input, params, macs + + + + +class AFFBlock(BaseModule): + + def __init__( + self, + opts, + in_channels: int, + transformer_dim: int, + ffn_dim: int, + n_transformer_blocks: Optional[int] = 2, + head_dim: Optional[int] = 32, + attn_dropout: Optional[float] = 0.0, + dropout: Optional[int] = 0.0, + ffn_dropout: Optional[int] = 0.0, + patch_h: Optional[int] = 8, + patch_w: Optional[int] = 8, + attn_norm_layer: Optional[str] = "layer_norm_2d", + conv_ksize: Optional[int] = 3, + dilation: Optional[int] = 1, + no_fusion: Optional[bool] = False, + *args, + **kwargs + ) -> None: + + conv_1x1_out = ConvLayer( + opts=opts, + in_channels=transformer_dim, + out_channels=in_channels, + kernel_size=1, + stride=1, + use_norm=True, + use_act=False, + ) + conv_3x3_out = None + if not no_fusion: + conv_3x3_out = ConvLayer( + opts=opts, + in_channels=2 * in_channels, + out_channels=in_channels, + kernel_size=1, # conv_ksize -> 1 + stride=1, + use_norm=True, + use_act=True, + ) + super().__init__() + + assert transformer_dim % head_dim == 0 + num_heads = transformer_dim // head_dim + self.enable_coreml_compatible_fn = getattr( + opts, "common.enable_coreml_compatible_module", False + ) or getattr(opts, "benchmark.use_jit_model", False) + print(self.enable_coreml_compatible_fn) + + global_rep = [ + # TODO: to check the double skip + Block( + opts=opts, + dim=transformer_dim, + hidden_size=transformer_dim, + num_blocks=8, + double_skip=False, + mlp_ratio=ffn_dim / transformer_dim, + attn_norm_layer=attn_norm_layer, + enable_coreml_compatible_fn=self.enable_coreml_compatible_fn + ) + for _ in range(n_transformer_blocks) + ] + global_rep.append( + get_normalization_layer( + opts=opts, + norm_type=attn_norm_layer, + num_features=transformer_dim, + ) + ) + self.global_rep = nn.Sequential(*global_rep) + + self.conv_proj = conv_1x1_out + + self.fusion = conv_3x3_out + + self.patch_h = patch_h + self.patch_w = patch_w + self.patch_area = self.patch_w * self.patch_h + + self.cnn_in_dim = in_channels + self.cnn_out_dim = transformer_dim + self.n_heads = num_heads + self.ffn_dim = ffn_dim + self.dropout = dropout + self.attn_dropout = attn_dropout + self.ffn_dropout = ffn_dropout + self.dilation = dilation + self.n_blocks = n_transformer_blocks + self.conv_ksize = conv_ksize + + def __repr__(self) -> str: + repr_str = "{}(".format(self.__class__.__name__) + + repr_str += "\n\t Global representations with patch size of {}x{}".format( + self.patch_h, self.patch_w + ) + if isinstance(self.global_rep, nn.Sequential): + for m in self.global_rep: + repr_str += "\n\t\t {}".format(m) + else: + repr_str += "\n\t\t {}".format(self.global_rep) + + if isinstance(self.conv_proj, nn.Sequential): + for m in self.conv_proj: + repr_str += "\n\t\t {}".format(m) + else: + repr_str += "\n\t\t {}".format(self.conv_proj) + + if self.fusion is not None: + repr_str += "\n\t Feature fusion" + if isinstance(self.fusion, nn.Sequential): + for m in self.fusion: + repr_str += "\n\t\t {}".format(m) + else: + repr_str += "\n\t\t {}".format(self.fusion) + + repr_str += "\n)" + return repr_str + + def forward_spatial(self, x: Tensor) -> Tensor: + res = x + + # fm = self.local_rep(x) + patches = x + + # b, c, h, w = fm.size() + # patches = einops.rearrange(fm, 'b c h w -> b (h w) c') + + # learn global representations + for transformer_layer in self.global_rep: + patches = transformer_layer(patches) + + # fm = einops.rearrange(patches, 'b (h w) c -> b c h w', h=h, w=w) + + fm = self.conv_proj(patches) + + if self.fusion is not None: + fm = self.fusion(torch.cat((res, fm), dim=1)) + return fm + + def forward_temporal( + self, x: Tensor, x_prev: Optional[Tensor] = None + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + + res = x + fm = self.local_rep(x) + + # # convert feature map to patches + # patches, info_dict = self.unfolding(fm) + + # learn global representations + for global_layer in self.global_rep: + if isinstance(global_layer, TransformerEncoder): + patches = global_layer(x=patches, x_prev=x_prev) + else: + patches = global_layer(patches) + + # # [B x Patch x Patches x C] --> [B x C x Patches x Patch] + # fm = self.folding(patches=patches, info_dict=info_dict) + + fm = self.conv_proj(fm) + + if self.fusion is not None: + fm = self.fusion(torch.cat((res, fm), dim=1)) + return fm, patches + + def forward( + self, x: Union[Tensor, Tuple[Tensor]], *args, **kwargs + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + if isinstance(x, Tuple) and len(x) == 2: + # for spatio-temporal MobileViT + return self.forward_temporal(x=x[0], x_prev=x[1]) + elif isinstance(x, Tensor): + # For image data + return self.forward_spatial(x) + else: + raise NotImplementedError + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + params = macs = 0.0 + + res = input + + b, c, h, w = input.size() + + out, p, m = module_profile(module=self.global_rep, x=input) + params += p + macs += m + + out, p, m = module_profile(module=self.conv_proj, x=out) + params += p + macs += m + + if self.fusion is not None: + out, p, m = module_profile( + module=self.fusion, x=torch.cat((out, res), dim=1) + ) + params += p + macs += m + + return res, params, macs + diff --git a/Adaptive Frequency Filters/affnet/modules/aspp_block.py b/Adaptive Frequency Filters/affnet/modules/aspp_block.py new file mode 100644 index 0000000..6770c6b --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/aspp_block.py @@ -0,0 +1,267 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from typing import Optional, Tuple +import torch.nn.functional as F + +from utils import logger +from utils.ddp_utils import is_master + +from ..layers import BaseLayer, ConvLayer, AdaptiveAvgPool2d, SeparableConv, Dropout2d +from ..modules import BaseModule +from ..misc.profiler import module_profile + + +class ASPP(BaseModule): + """ + ASPP module defined in DeepLab papers, `here `_ and `here `_ + + Args: + opts: command-line arguments + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)` + out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H, W)` + atrous_rates (Tuple[int]): atrous rates for different branches. + is_sep_conv (Optional[bool]): Use separable convolution instead of standaard conv. Default: False + dropout (Optional[float]): Apply dropout. Default is 0.0 + + Shape: + - Input: :math:`(N, C_{in}, H, W)` + - Output: :math:`(N, C_{out}, H, W)` + """ + + def __init__( + self, + opts, + in_channels: int, + out_channels: int, + atrous_rates: Tuple[int], + is_sep_conv: Optional[bool] = False, + dropout: Optional[float] = 0.0, + *args, + **kwargs + ) -> None: + in_proj = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + use_norm=True, + use_act=True, + ) + out_proj = ConvLayer( + opts=opts, + in_channels=5 * out_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + use_norm=True, + use_act=True, + ) + aspp_layer = ASPPSeparableConv if is_sep_conv else ASPPConv + + assert len(atrous_rates) == 3 + + modules = [in_proj] + modules.extend( + [ + aspp_layer( + opts=opts, + in_channels=in_channels, + out_channels=out_channels, + dilation=rate, + ) + for rate in atrous_rates + ] + ) + modules.append( + ASPPPooling(opts=opts, in_channels=in_channels, out_channels=out_channels) + ) + + if not (0.0 <= dropout < 1.0): + if is_master(opts): + logger.warning( + "Dropout value in {} should be between 0 and 1. Got: {}. Setting it to 0.0".format( + self.__class__.__name__, dropout + ) + ) + dropout = 0.0 + + super().__init__() + self.convs = nn.ModuleList(modules) + self.project = out_proj + + self.in_channels = in_channels + self.out_channels = out_channels + self.atrous_rates = atrous_rates + self.is_sep_conv_layer = is_sep_conv + self.n_atrous_branches = len(atrous_rates) + self.dropout_layer = Dropout2d(p=dropout) + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + out = [] + for conv in self.convs: + out.append(conv(x)) + out = torch.cat(out, dim=1) + out = self.project(out) + out = self.dropout_layer(out) + return out + + def profile_module(self, input: Tensor, *args, **kwargs) -> (Tensor, float, float): + params, macs = 0.0, 0.0 + res = [] + for c in self.convs: + out, p, m = module_profile(module=c, x=input) + params += p + macs += m + res.append(out) + res = torch.cat(res, dim=1) + + out, p, m = module_profile(module=self.project, x=res) + params += p + macs += m + return out, params, macs + + def __repr__(self): + return "{}(in_channels={}, out_channels={}, atrous_rates={}, is_aspp_sep={}, dropout={})".format( + self.__class__.__name__, + self.in_channels, + self.out_channels, + self.atrous_rates, + self.is_sep_conv_layer, + self.dropout_layer.p, + ) + + +class ASPPConv(ConvLayer): + """ + Convolution with a dilation for the ASPP module + Args: + opts: command-line arguments + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)` + out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H, W)` + dilation (int): Dilation rate + + Shape: + - Input: :math:`(N, C_{in}, H, W)` + - Output: :math:`(N, C_{out}, H, W)` + """ + + def __init__( + self, opts, in_channels: int, out_channels: int, dilation: int, *args, **kwargs + ) -> None: + super().__init__( + opts=opts, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + use_norm=True, + use_act=True, + dilation=dilation, + ) + + def adjust_atrous_rate(self, rate: int) -> None: + """This function allows to adjust the dilation rate""" + self.block.conv.dilation = rate + # padding is the same here + # see ConvLayer to see the method for computing padding + self.block.conv.padding = rate + + +class ASPPSeparableConv(SeparableConv): + """ + Separable Convolution with a dilation for the ASPP module + Args: + opts: command-line arguments + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)` + out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H, W)` + dilation (int): Dilation rate + + Shape: + - Input: :math:`(N, C_{in}, H, W)` + - Output: :math:`(N, C_{out}, H, W)` + """ + + def __init__( + self, opts, in_channels: int, out_channels: int, dilation: int, *args, **kwargs + ) -> None: + super().__init__( + opts=opts, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + dilation=dilation, + use_norm=True, + use_act=True, + ) + + def adjust_atrous_rate(self, rate: int) -> None: + """This function allows to adjust the dilation rate""" + self.dw_conv.block.conv.dilation = rate + # padding is the same here + # see ConvLayer to see the method for computing padding + self.dw_conv.block.conv.padding = rate + + +class ASPPPooling(BaseLayer): + """ + ASPP pooling layer + Args: + opts: command-line arguments + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)` + out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H, W)` + + Shape: + - Input: :math:`(N, C_{in}, H, W)` + - Output: :math:`(N, C_{out}, H, W)` + """ + + def __init__( + self, opts, in_channels: int, out_channels: int, *args, **kwargs + ) -> None: + + super().__init__() + self.aspp_pool = nn.Sequential() + self.aspp_pool.add_module( + name="global_pool", module=AdaptiveAvgPool2d(output_size=1) + ) + self.aspp_pool.add_module( + name="conv_1x1", + module=ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + use_norm=True, + use_act=True, + ), + ) + + self.in_channels = in_channels + self.out_channels = out_channels + + def forward(self, x: Tensor) -> Tensor: + x_size = x.shape[-2:] + x = self.aspp_pool(x) + x = F.interpolate(x, size=x_size, mode="bilinear", align_corners=False) + return x + + def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]: + out, params, macs = module_profile(module=self.aspp_pool, x=input) + out = F.interpolate( + out, size=input.shape[-2:], mode="bilinear", align_corners=False + ) + return out, params, macs + + def __repr__(self): + return "{}(in_channels={}, out_channels={})".format( + self.__class__.__name__, self.in_channels, self.out_channels + ) diff --git a/Adaptive Frequency Filters/affnet/modules/base_module.py b/Adaptive Frequency Filters/affnet/modules/base_module.py new file mode 100644 index 0000000..c8e74ab --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/base_module.py @@ -0,0 +1,25 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from typing import Tuple, Union, Any + + +class BaseModule(nn.Module): + """Base class for all modules""" + + def __init__(self, *args, **kwargs): + super(BaseModule, self).__init__() + + def forward(self, x: Any, *args, **kwargs) -> Any: + raise NotImplementedError + + def profile_module(self, input: Any, *args, **kwargs) -> Tuple[Any, float, float]: + raise NotImplementedError + + def __repr__(self): + return "{}".format(self.__class__.__name__) diff --git a/Adaptive Frequency Filters/affnet/modules/cbam.py b/Adaptive Frequency Filters/affnet/modules/cbam.py new file mode 100644 index 0000000..ec51f4e --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/cbam.py @@ -0,0 +1,211 @@ +import torch +import math +import torch.nn as nn +import torch.nn.functional as F + +from affnet.layers import ConvLayer, get_activation_fn, get_normalization_layer + + +# class BasicConv(nn.Module): +# def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): +# super(BasicConv, self).__init__() +# self.out_channels = out_planes +# self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) +# self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None +# self.relu = nn.ReLU() if relu else None +# +# def forward(self, x): +# x = self.conv(x) +# if self.bn is not None: +# x = self.bn(x) +# if self.relu is not None: +# x = self.relu(x) +# return x + +class Flatten(nn.Module): + def forward(self, x): + return x.view(x.size(0), -1) + +class View_as_complex(nn.Module): + def forward(self, x): + return torch.view_as_complex(x) + +class View_as_real(nn.Module): + def forward(self, x): + return torch.view_as_real(x) + +class ChannelGateComplex(nn.Module): + + def __init__(self, opts, gate_channels, reduction_ratio=16): + super(ChannelGateComplex, self).__init__() + self.gate_channels = gate_channels + self.mlp = nn.Sequential( + nn.Conv2d(gate_channels, gate_channels // reduction_ratio, kernel_size=1, dtype=torch.complex64), + View_as_real(), + self.build_act_layer(opts=opts), + View_as_complex(), + nn.Conv2d(gate_channels // reduction_ratio, gate_channels, kernel_size=1, dtype=torch.complex64) + ) + + @staticmethod + def build_act_layer(opts) -> nn.Module: + act_type = getattr(opts, "model.activation.name", "relu") + neg_slope = getattr(opts, "model.activation.neg_slope", 0.1) + inplace = getattr(opts, "model.activation.inplace", False) + act_layer = get_activation_fn( + act_type=act_type, + inplace=inplace, + negative_slope=neg_slope, + num_parameters=1, + ) + return act_layer + + def forward(self, x): + # input complex + input = x + x = torch.view_as_real(x) + avg_pool = torch.mean(x, dim=(2, 3), keepdim=True) + max_pool = torch.amax(x, dim=[2, 3], keepdim=True) + avg_pool = torch.view_as_complex(avg_pool) + max_pool = torch.view_as_complex(max_pool) + avg_pool = self.mlp(avg_pool) + max_pool = self.mlp(max_pool) + channel_att_sum = torch.view_as_real(avg_pool + max_pool) + output = torch.view_as_complex(F.sigmoid(channel_att_sum)) * input + + return output + +class ChannelGate(nn.Module): + + def __init__(self, opts, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): + super(ChannelGate, self).__init__() + self.gate_channels = gate_channels + self.mlp = nn.Sequential( + Flatten(), + nn.Linear(gate_channels, gate_channels // reduction_ratio), + self.build_act_layer(opts=opts), + nn.Linear(gate_channels // reduction_ratio, gate_channels) + ) + self.pool_types = pool_types + + @staticmethod + def build_act_layer(opts) -> nn.Module: + act_type = getattr(opts, "model.activation.name", "relu") + neg_slope = getattr(opts, "model.activation.neg_slope", 0.1) + inplace = getattr(opts, "model.activation.inplace", False) + act_layer = get_activation_fn( + act_type=act_type, + inplace=inplace, + negative_slope=neg_slope, + num_parameters=1, + ) + return act_layer + + def forward(self, x): + channel_att_sum = None + for pool_type in self.pool_types: + if pool_type=='avg': + avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp( avg_pool ) + elif pool_type=='max': + max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp( max_pool ) + elif pool_type=='lp': + lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp( lp_pool ) + elif pool_type=='lse': + # LSE pool only + lse_pool = logsumexp_2d(x) + channel_att_raw = self.mlp( lse_pool ) + + if channel_att_sum is None: + channel_att_sum = channel_att_raw + else: + channel_att_sum = channel_att_sum + channel_att_raw + + scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) + return x * scale + +def logsumexp_2d(tensor): + tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) + s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) + outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() + return outputs + +class ChannelPool(nn.Module): + def forward(self, x): + return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) + +class ChannelPoolComplex(nn.Module): + def forward(self, x): + x = torch.view_as_real(x) + x = torch.cat( (torch.max(x, 1, keepdim=True)[0], torch.mean(x, 1, keepdim=True)), dim=1) + return torch.view_as_complex(x) + +class SpatialGateComplex(nn.Module): + def __init__(self, opts): + super(SpatialGateComplex, self).__init__() + kernel_size = 7 + self.compress = ChannelPoolComplex() + # self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) + self.spatial = nn.Conv2d(2, 1, kernel_size=kernel_size, stride=1, padding=(kernel_size-1) // 2, + dtype=torch.complex64) + self.spatial_bn = nn.BatchNorm3d(1, eps=1e-5, momentum=0.01, affine=True) + def forward(self, x): + input = x + x_compress = self.compress(x) + x_out = self.spatial(x_compress) + x_out = torch.view_as_real(x_out) + x_out = self.spatial_bn(x_out) + scale = F.sigmoid(x_out) # broadcasting + scale = torch.view_as_complex(scale) + return input * scale +class SpatialGate(nn.Module): + def __init__(self, opts): + super(SpatialGate, self).__init__() + kernel_size = 7 + self.compress = ChannelPool() + # self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) + self.spatial = ConvLayer(opts=opts, in_channels=2, out_channels=1, kernel_size=kernel_size, + stride=1, padding=(kernel_size-1) // 2, use_norm=True, use_act=False) + def forward(self, x): + x_compress = self.compress(x) + x_out = self.spatial(x_compress) + scale = F.sigmoid(x_out) # broadcasting + return x * scale + +class CBAM(nn.Module): + + def __init__(self, opts, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False, no_channel=False): + super(CBAM, self).__init__() + self.no_channel = no_channel + if not no_channel: + self.ChannelGate = ChannelGate(opts, gate_channels, reduction_ratio, pool_types) + self.no_spatial=no_spatial + if not no_spatial: + self.SpatialGate = SpatialGate(opts=opts) + + def forward(self, x): + if not self.no_channel: + x_out = self.ChannelGate(x) + else: + x_out = x + if not self.no_spatial: + x_out = self.SpatialGate(x_out) + return x_out + + +class CBAMComplex(nn.Module): + + def __init__(self, opts, gate_channels, reduction_ratio=16, no_spatial=False): + super(CBAMComplex, self).__init__() + self.ChannelGate = ChannelGateComplex(opts, gate_channels, reduction_ratio) + self.no_spatial=no_spatial + if not no_spatial: + self.SpatialGate = SpatialGateComplex(opts=opts) + + def forward(self, x): + x_out = self.ChannelGate(x) + if not self.no_spatial: + x_out = self.SpatialGate(x_out) + return x_out \ No newline at end of file diff --git a/Adaptive Frequency Filters/affnet/modules/complexFunctions.py b/Adaptive Frequency Filters/affnet/modules/complexFunctions.py new file mode 100644 index 0000000..bd40164 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/complexFunctions.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +@author: spopoff +""" + +from torch.nn.functional import relu, max_pool2d, avg_pool2d, dropout, dropout2d, interpolate, sigmoid, tanh +import torch + +def complex_matmul(A, B): + ''' + Performs the matrix product between two complex matricess + ''' + + outp_real = torch.matmul(A.real, B.real) - torch.matmul(A.imag, B.imag) + outp_imag = torch.matmul(A.real, B.imag) + torch.matmul(A.imag, B.real) + + return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64) + +def complex_avg_pool2d(input, *args, **kwargs): + ''' + Perform complex average pooling. + ''' + absolute_value_real = avg_pool2d(input.real, *args, **kwargs) + absolute_value_imag = avg_pool2d(input.imag, *args, **kwargs) + + return absolute_value_real.type(torch.complex64)+1j*absolute_value_imag.type(torch.complex64) + +def complex_normalize(input): + ''' + Perform complex normalization + ''' + real_value, imag_value = input.real, input.imag + real_norm = (real_value - real_value.mean()) / real_value.std() + imag_norm = (imag_value - imag_value.mean()) / imag_value.std() + + return real_norm.type(torch.complex64) + 1j*imag_norm.type(torch.complex64) + +def complex_relu(input): + return relu(input.real).type(torch.complex64)+1j*relu(input.imag).type(torch.complex64) + +def complex_relu(input): + return relu(input.real).type(torch.complex64)+1j*relu(input.imag).type(torch.complex64) + +def complex_sigmoid(input): + return sigmoid(input.real).type(torch.complex64)+1j*sigmoid(input.imag).type(torch.complex64) + +def complex_tanh(input): + return tanh(input.real).type(torch.complex64)+1j*tanh(input.imag).type(torch.complex64) + +def complex_opposite(input): + return -(input.real).type(torch.complex64)+1j*(-(input.imag).type(torch.complex64)) + +def complex_stack(input, dim): + input_real = [x.real for x in input] + input_imag = [x.imag for x in input] + return torch.stack(input_real, dim).type(torch.complex64)+1j*torch.stack(input_imag, dim).type(torch.complex64) + +def _retrieve_elements_from_indices(tensor, indices): + flattened_tensor = tensor.flatten(start_dim=-2) + output = flattened_tensor.gather(dim=-1, index=indices.flatten(start_dim=-2)).view_as(indices) + return output + +def complex_upsample(input, size=None, scale_factor=None, mode='nearest', + align_corners=None, recompute_scale_factor=None): + ''' + Performs upsampling by separately interpolating the real and imaginary part and recombining + ''' + outp_real = interpolate(input.real, size=size, scale_factor=scale_factor, mode=mode, + align_corners=align_corners, recompute_scale_factor=recompute_scale_factor) + outp_imag = interpolate(input.imag, size=size, scale_factor=scale_factor, mode=mode, + align_corners=align_corners, recompute_scale_factor=recompute_scale_factor) + + return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64) + +def complex_upsample2(input, size=None, scale_factor=None, mode='nearest', + align_corners=None, recompute_scale_factor=None): + ''' + Performs upsampling by separately interpolating the amplitude and phase part and recombining + ''' + outp_abs = interpolate(input.abs(), size=size, scale_factor=scale_factor, mode=mode, + align_corners=align_corners, recompute_scale_factor=recompute_scale_factor) + angle = torch.atan2(input.imag,input.real) + outp_angle = interpolate(angle, size=size, scale_factor=scale_factor, mode=mode, + align_corners=align_corners, recompute_scale_factor=recompute_scale_factor) + + return outp_abs \ + * (torch.cos(angle).type(torch.complex64)+1j*torch.sin(angle).type(torch.complex64)) + + +def complex_max_pool2d(input,kernel_size, stride=None, padding=0, + dilation=1, ceil_mode=False, return_indices=False): + ''' + Perform complex max pooling by selecting on the absolute value on the complex values. + ''' + absolute_value, indices = max_pool2d( + input.abs(), + kernel_size = kernel_size, + stride = stride, + padding = padding, + dilation = dilation, + ceil_mode = ceil_mode, + return_indices = True + ) + # performs the selection on the absolute values + absolute_value = absolute_value.type(torch.complex64) + # retrieve the corresponding phase value using the indices + # unfortunately, the derivative for 'angle' is not implemented + angle = torch.atan2(input.imag,input.real) + # get only the phase values selected by max pool + angle = _retrieve_elements_from_indices(angle, indices) + return absolute_value \ + * (torch.cos(angle).type(torch.complex64)+1j*torch.sin(angle).type(torch.complex64)) + +def complex_dropout(input, p=0.5, training=True): + # need to have the same dropout mask for real and imaginary part, + # this not a clean solution! + device = input.device + mask = torch.ones(*input.shape, dtype = torch.float32, device = device) + mask = dropout(mask, p, training)*1/(1-p) + mask.type(input.dtype) + return mask*input + + +def complex_dropout2d(input, p=0.5, training=True): + # need to have the same dropout mask for real and imaginary part, + # this not a clean solution! + device = input.device + mask = torch.ones(*input.shape, dtype = torch.float32, device = device) + mask = dropout2d(mask, p, training)*1/(1-p) + mask.type(input.dtype) + return mask*input \ No newline at end of file diff --git a/Adaptive Frequency Filters/affnet/modules/complexLayers.py b/Adaptive Frequency Filters/affnet/modules/complexLayers.py new file mode 100644 index 0000000..3e61acb --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/complexLayers.py @@ -0,0 +1,466 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Tue Mar 19 10:30:02 2019 +@author: Sebastien M. Popoff +Based on https://openreview.net/forum?id=H1T2hmZAb +""" + +import torch +from torch.nn import Module, Parameter, init +from torch.nn import Conv2d, Linear, BatchNorm1d, BatchNorm2d +from torch.nn import ConvTranspose2d +from .complexFunctions import complex_relu, complex_max_pool2d, complex_avg_pool2d +from .complexFunctions import complex_dropout, complex_dropout2d + +def apply_complex(fr, fi, input): + return torch.complex( + (fr(input.real)-fi(input.imag)), + (fr(input.imag)+fi(input.real)) + ) + +class ComplexDropout(Module): + def __init__(self,p=0.5): + super(ComplexDropout,self).__init__() + self.p = p + + def forward(self,input): + if self.training: + return complex_dropout(input,self.p) + else: + return input + +class ComplexDropout2d(Module): + def __init__(self,p=0.5): + super(ComplexDropout2d,self).__init__() + self.p = p + + def forward(self,input): + if self.training: + return complex_dropout2d(input,self.p) + else: + return input + +class ComplexMaxPool2d(Module): + + def __init__(self,kernel_size, stride= None, padding = 0, + dilation = 1, return_indices = False, ceil_mode = False): + super(ComplexMaxPool2d,self).__init__() + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.ceil_mode = ceil_mode + self.return_indices = return_indices + + def forward(self,input): + return complex_max_pool2d(input,kernel_size = self.kernel_size, + stride = self.stride, padding = self.padding, + dilation = self.dilation, ceil_mode = self.ceil_mode, + return_indices = self.return_indices) + + +class ComplexAvgPool2d(Module): + + def __init__(self,kernel_size, stride= None, padding = 0, + dilation = 1, return_indices = False, ceil_mode = False): + super(ComplexAvgPool2d,self).__init__() + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.ceil_mode = ceil_mode + self.return_indices = return_indices + + def forward(self,input): + return complex_avg_pool2d(input,kernel_size = self.kernel_size, + stride = self.stride, padding = self.padding, + dilation = self.dilation, ceil_mode = self.ceil_mode, + return_indices = self.return_indices) + +class ComplexReLU(Module): + + def forward(self,input): + return complex_relu(input) + +class ComplexSigmoid(Module): + + def forward(self,input): + return complex_sigmoid(input) + +class ComplexTanh(Module): + + def forward(self,input): + return complex_tanh(input) + +class ComplexConvTranspose2d(Module): + + def __init__(self,in_channels, out_channels, kernel_size, stride=1, padding=0, + output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros'): + + super(ComplexConvTranspose2d, self).__init__() + + self.conv_tran_r = ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, + output_padding, groups, bias, dilation, padding_mode) + self.conv_tran_i = ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, + output_padding, groups, bias, dilation, padding_mode) + + + def forward(self,input): + return apply_complex(self.conv_tran_r, self.conv_tran_i, input) + +class ComplexConv2d(Module): + + def __init__(self,in_channels, out_channels, kernel_size=3, stride=1, padding = 0, + dilation=1, groups=1, bias=True): + super(ComplexConv2d, self).__init__() + self.conv_r = Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) + self.conv_i = Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) + + def forward(self,input): + return apply_complex(self.conv_r, self.conv_i, input) + +class ComplexLinear(Module): + + def __init__(self, in_features, out_features): + super(ComplexLinear, self).__init__() + self.fc_r = Linear(in_features, out_features) + self.fc_i = Linear(in_features, out_features) + + def forward(self, input): + return apply_complex(self.fc_r, self.fc_i, input) + + +class NaiveComplexBatchNorm1d(Module): + ''' + Naive approach to complex batch norm, perform batch norm independently on real and imaginary part. + ''' + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, \ + track_running_stats=True): + super(NaiveComplexBatchNorm1d, self).__init__() + self.bn_r = BatchNorm1d(num_features, eps, momentum, affine, track_running_stats) + self.bn_i = BatchNorm1d(num_features, eps, momentum, affine, track_running_stats) + + def forward(self,input): + return self.bn_r(input.real).type(torch.complex64) +1j*self.bn_i(input.imag).type(torch.complex64) + +class NaiveComplexBatchNorm2d(Module): + ''' + Naive approach to complex batch norm, perform batch norm independently on real and imaginary part. + ''' + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, \ + track_running_stats=True): + super(NaiveComplexBatchNorm2d, self).__init__() + self.bn_r = BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) + self.bn_i = BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) + + def forward(self,input): + return self.bn_r(input.real).type(torch.complex64) +1j*self.bn_i(input.imag).type(torch.complex64) + +class _ComplexBatchNorm(Module): + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, + track_running_stats=True): + super(_ComplexBatchNorm, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.affine = affine + self.track_running_stats = track_running_stats + if self.affine: + self.weight = Parameter(torch.Tensor(num_features,3)) + self.bias = Parameter(torch.Tensor(num_features,2)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + if self.track_running_stats: + self.register_buffer('running_mean', torch.zeros(num_features, dtype = torch.complex64)) + self.register_buffer('running_covar', torch.zeros(num_features,3)) + self.running_covar[:,0] = 1.4142135623730951 + self.running_covar[:,1] = 1.4142135623730951 + self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) + else: + self.register_parameter('running_mean', None) + self.register_parameter('running_covar', None) + self.register_parameter('num_batches_tracked', None) + self.reset_parameters() + + def reset_running_stats(self): + if self.track_running_stats: + self.running_mean.zero_() + self.running_covar.zero_() + self.running_covar[:,0] = 1.4142135623730951 + self.running_covar[:,1] = 1.4142135623730951 + self.num_batches_tracked.zero_() + + def reset_parameters(self): + self.reset_running_stats() + if self.affine: + init.constant_(self.weight[:,:2],1.4142135623730951) + init.zeros_(self.weight[:,2]) + init.zeros_(self.bias) + +class ComplexBatchNorm2d(_ComplexBatchNorm): + + def forward(self, input): + exponential_average_factor = 0.0 + + if self.training and self.track_running_stats: + if self.num_batches_tracked is not None: + self.num_batches_tracked += 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + if self.training or (not self.training and not self.track_running_stats): + # calculate mean of real and imaginary part + # mean does not support automatic differentiation for outputs with complex dtype. + mean_r = input.real.mean([0, 2, 3]).type(torch.complex64) + mean_i = input.imag.mean([0, 2, 3]).type(torch.complex64) + mean = mean_r + 1j*mean_i + else: + mean = self.running_mean + + if self.training and self.track_running_stats: + # update running mean + with torch.no_grad(): + self.running_mean = exponential_average_factor * mean \ + + (1 - exponential_average_factor) * self.running_mean + + input = input - mean[None, :, None, None] + + if self.training or (not self.training and not self.track_running_stats): + # Elements of the covariance matrix (biased for train) + n = input.numel() / input.size(1) + Crr = 1./n*input.real.pow(2).sum(dim=[0,2,3])+self.eps + Cii = 1./n*input.imag.pow(2).sum(dim=[0,2,3])+self.eps + Cri = (input.real.mul(input.imag)).mean(dim=[0,2,3]) + else: + Crr = self.running_covar[:,0]+self.eps + Cii = self.running_covar[:,1]+self.eps + Cri = self.running_covar[:,2]#+self.eps + + if self.training and self.track_running_stats: + with torch.no_grad(): + self.running_covar[:,0] = exponential_average_factor * Crr * n / (n - 1) \ + + (1 - exponential_average_factor) * self.running_covar[:,0] + + self.running_covar[:,1] = exponential_average_factor * Cii * n / (n - 1) \ + + (1 - exponential_average_factor) * self.running_covar[:,1] + + self.running_covar[:,2] = exponential_average_factor * Cri * n / (n - 1) \ + + (1 - exponential_average_factor) * self.running_covar[:,2] + + # calculate the inverse square root the covariance matrix + det = Crr*Cii-Cri.pow(2) + s = torch.sqrt(det) + t = torch.sqrt(Cii+Crr + 2 * s) + inverse_st = 1.0 / (s * t) + Rrr = (Cii + s) * inverse_st + Rii = (Crr + s) * inverse_st + Rri = -Cri * inverse_st + + input = (Rrr[None,:,None,None]*input.real+Rri[None,:,None,None]*input.imag).type(torch.complex64) \ + + 1j*(Rii[None,:,None,None]*input.imag+Rri[None,:,None,None]*input.real).type(torch.complex64) + + if self.affine: + input = (self.weight[None,:,0,None,None]*input.real+self.weight[None,:,2,None,None]*input.imag+ \ + self.bias[None,:,0,None,None]).type(torch.complex64) \ + +1j*(self.weight[None,:,2,None,None]*input.real+self.weight[None,:,1,None,None]*input.imag+ \ + self.bias[None,:,1,None,None]).type(torch.complex64) + + return input + + +class ComplexBatchNorm1d(_ComplexBatchNorm): + + def forward(self, input): + + exponential_average_factor = 0.0 + + + if self.training and self.track_running_stats: + if self.num_batches_tracked is not None: + self.num_batches_tracked += 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + if self.training or (not self.training and not self.track_running_stats): + # calculate mean of real and imaginary part + mean_r = input.real.mean(dim=0).type(torch.complex64) + mean_i = input.imag.mean(dim=0).type(torch.complex64) + mean = mean_r + 1j*mean_i + else: + mean = self.running_mean + + if self.training and self.track_running_stats: + # update running mean + with torch.no_grad(): + self.running_mean = exponential_average_factor * mean \ + + (1 - exponential_average_factor) * self.running_mean + + input = input - mean[None, ...] + + if self.training or (not self.training and not self.track_running_stats): + # Elements of the covariance matrix (biased for train) + n = input.numel() / input.size(1) + Crr = input.real.var(dim=0,unbiased=False)+self.eps + Cii = input.imag.var(dim=0,unbiased=False)+self.eps + Cri = (input.real.mul(input.imag)).mean(dim=0) + else: + Crr = self.running_covar[:,0]+self.eps + Cii = self.running_covar[:,1]+self.eps + Cri = self.running_covar[:,2] + + if self.training and self.track_running_stats: + self.running_covar[:,0] = exponential_average_factor * Crr * n / (n - 1) \ + + (1 - exponential_average_factor) * self.running_covar[:,0] + + self.running_covar[:,1] = exponential_average_factor * Cii * n / (n - 1) \ + + (1 - exponential_average_factor) * self.running_covar[:,1] + + self.running_covar[:,2] = exponential_average_factor * Cri * n / (n - 1) \ + + (1 - exponential_average_factor) * self.running_covar[:,2] + + # calculate the inverse square root the covariance matrix + det = Crr*Cii-Cri.pow(2) + s = torch.sqrt(det) + t = torch.sqrt(Cii+Crr + 2 * s) + inverse_st = 1.0 / (s * t) + Rrr = (Cii + s) * inverse_st + Rii = (Crr + s) * inverse_st + Rri = -Cri * inverse_st + + input = (Rrr[None,:]*input.real+Rri[None,:]*input.imag).type(torch.complex64) \ + + 1j*(Rii[None,:]*input.imag+Rri[None,:]*input.real).type(torch.complex64) + + if self.affine: + input = (self.weight[None,:,0]*input.real+self.weight[None,:,2]*input.imag+ \ + self.bias[None,:,0]).type(torch.complex64) \ + +1j*(self.weight[None,:,2]*input.real+self.weight[None,:,1]*input.imag+ \ + self.bias[None,:,1]).type(torch.complex64) + + + del Crr, Cri, Cii, Rrr, Rii, Rri, det, s, t + return input + +class ComplexGRUCell(Module): + """ + A GRU cell for complex-valued inputs + """ + + def __init__(self, input_length=10, hidden_length=20): + super(ComplexGRUCell, self).__init__() + self.input_length = input_length + self.hidden_length = hidden_length + + # reset gate components + self.linear_reset_w1 = ComplexLinear(self.input_length, self.hidden_length) + self.linear_reset_r1 = ComplexLinear(self.hidden_length, self.hidden_length) + + self.linear_reset_w2 = ComplexLinear(self.input_length, self.hidden_length) + self.linear_reset_r2 = ComplexLinear(self.hidden_length, self.hidden_length) + + # update gate components + self.linear_gate_w3 = ComplexLinear(self.input_length, self.hidden_length) + self.linear_gate_r3 = ComplexLinear(self.hidden_length, self.hidden_length) + + self.activation_gate = ComplexSigmoid() + self.activation_candidate = ComplexTanh() + + def reset_gate(self, x, h): + x_1 = self.linear_reset_w1(x) + h_1 = self.linear_reset_r1(h) + # gate update + reset = self.activation_gate(x_1 + h_1) + return reset + + def update_gate(self, x, h): + x_2 = self.linear_reset_w2(x) + h_2 = self.linear_reset_r2(h) + z = self.activation_gate(h_2 + x_2) + return z + + def update_component(self, x, h, r): + x_3 = self.linear_gate_w3(x) + h_3 = r * self.linear_gate_r3(h) # element-wise multiplication + gate_update = self.activation_candidate(x_3 + h_3) + return gate_update + + def forward(self, x, h): + # Equation 1. reset gate vector + r = self.reset_gate(x, h) + + # Equation 2: the update gate - the shared update gate vector z + z = self.update_gate(x, h) + + # Equation 3: The almost output component + n = self.update_component(x, h, r) + + # Equation 4: the new hidden state + h_new = (1 + complex_opposite(z)) * n + z * h # element-wise multiplication + + return h_new + +class ComplexBNGRUCell(Module): + """ + A BN-GRU cell for complex-valued inputs + """ + + def __init__(self, input_length=10, hidden_length=20): + super(ComplexBNGRUCell, self).__init__() + self.input_length = input_length + self.hidden_length = hidden_length + + # reset gate components + self.linear_reset_w1 = ComplexLinear(self.input_length, self.hidden_length) + self.linear_reset_r1 = ComplexLinear(self.hidden_length, self.hidden_length) + + self.linear_reset_w2 = ComplexLinear(self.input_length, self.hidden_length) + self.linear_reset_r2 = ComplexLinear(self.hidden_length, self.hidden_length) + + # update gate components + self.linear_gate_w3 = ComplexLinear(self.input_length, self.hidden_length) + self.linear_gate_r3 = ComplexLinear(self.hidden_length, self.hidden_length) + + self.activation_gate = ComplexSigmoid() + self.activation_candidate = ComplexTanh() + + self.bn = ComplexBatchNorm2d(1) + + def reset_gate(self, x, h): + x_1 = self.linear_reset_w1(x) + h_1 = self.linear_reset_r1(h) + # gate update + reset = self.activation_gate(self.bn(x_1) + self.bn(h_1)) + return reset + + def update_gate(self, x, h): + x_2 = self.linear_reset_w2(x) + h_2 = self.linear_reset_r2(h) + z = self.activation_gate(self.bn(h_2) + self.bn(x_2)) + return z + + def update_component(self, x, h, r): + x_3 = self.linear_gate_w3(x) + h_3 = r * self.bn(self.linear_gate_r3(h)) # element-wise multiplication + gate_update = self.activation_candidate(self.bn(self.bn(x_3) + h_3)) + return gate_update + + def forward(self, x, h): + # Equation 1. reset gate vector + r = self.reset_gate(x, h) + + # Equation 2: the update gate - the shared update gate vector z + z = self.update_gate(x, h) + + # Equation 3: The almost output component + n = self.update_component(x, h, r) + + # Equation 4: the new hidden state + h_new = (1 + complex_opposite(z)) * n + z * h # element-wise multiplication + + return h_new \ No newline at end of file diff --git a/Adaptive Frequency Filters/affnet/modules/efficientnet.py b/Adaptive Frequency Filters/affnet/modules/efficientnet.py new file mode 100644 index 0000000..56d79cf --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/efficientnet.py @@ -0,0 +1,52 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Optional, Union, Tuple + +from ..layers import StochasticDepth + +from . import InvertedResidualSE + + +class EfficientNetBlock(InvertedResidualSE): + """ + This class implements a variant of the inverted residual block with squeeze-excitation unit, + as described in `MobileNetv3 `_ paper. This variant + includes stochastic depth, as used in `EfficientNet `_ paper. + + Args: + stochastic_depth_prob: float, + For other arguments, refer to the parent class. + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` + """ + + def __init__(self, stochastic_depth_prob: float, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.stochastic_depth = StochasticDepth(p=stochastic_depth_prob, mode="row") + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + y = self.block(x) + if self.use_res_connect: + # Pass the output through the stochastic layer module, potentially zeroing it. + y = self.stochastic_depth(y) + # residual connection + y = y + x + return y + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + return super().profile_module(input=input) + + def __repr__(self) -> str: + return ( + super().__repr__()[:-1] + + f", stochastic_depth_prob={self.stochastic_depth.p})" + ) diff --git a/Adaptive Frequency Filters/affnet/modules/feature_pyramid.py b/Adaptive Frequency Filters/affnet/modules/feature_pyramid.py new file mode 100644 index 0000000..908200c --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/feature_pyramid.py @@ -0,0 +1,175 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from typing import Dict, List +import torch.nn.functional as F + +from utils import logger + +from ..layers import ConvLayer, norm_layers_tuple +from ..modules import BaseModule +from ..misc.profiler import module_profile +from ..misc.init_utils import initialize_conv_layer, initialize_norm_layers + + +class FeaturePyramidNetwork(BaseModule): + """ + This class implements the `Feature Pyramid Network `_ module for object detection. + + Args: + opts: command-line arguments + in_channels (List[int]): List of channels at different output strides + output_strides (List[int]): Feature maps from these output strides will be used in FPN + out_channels (int): Output channels + + """ + + def __init__( + self, + opts, + in_channels: List[int], + output_strides: List[str], + out_channels: int, + *args, + **kwargs + ) -> None: + + if isinstance(in_channels, int): + in_channels = [in_channels] + if isinstance(output_strides, int): + output_strides = [output_strides] + + if len(in_channels) != len(output_strides): + logger.error( + "For {}, we need the length of input_channels to be the same as the length of output stride. " + "Got: {} and {}".format( + self.__class__.__name__, len(in_channels), len(output_strides) + ) + ) + assert len(in_channels) == len(output_strides) + super().__init__(*args, **kwargs) + + self.proj_layers = nn.ModuleDict() + self.nxn_convs = nn.ModuleDict() + + for os, in_channel in zip(output_strides, in_channels): + proj_layer = ConvLayer( + opts=opts, + in_channels=in_channel, + out_channels=out_channels, + kernel_size=1, + bias=False, + use_norm=True, + use_act=False, + ) + nxn_conv = ConvLayer( + opts=opts, + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + bias=False, + use_norm=True, + use_act=False, + ) + + self.proj_layers.add_module(name="os_{}".format(os), module=proj_layer) + self.nxn_convs.add_module(name="os_{}".format(os), module=nxn_conv) + + self.num_fpn_layers = len(in_channels) + self.out_channels = out_channels + self.in_channels = in_channels + self.output_strides = output_strides + + self.reset_weights() + + def reset_weights(self) -> None: + """Resets the weights of FPN layers""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + initialize_conv_layer(m, init_method="xavier_uniform") + elif isinstance(m, norm_layers_tuple): + initialize_norm_layers(m) + + def forward(self, x: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]: + assert len(x) == self.num_fpn_layers + + # dictionary to store results for fpn + fpn_out_dict = {"os_".format(os): None for os in self.output_strides} + + # process the last output stride + os_key = "os_{}".format(self.output_strides[-1]) + prev_x = self.proj_layers[os_key](x[os_key]) + prev_x = self.nxn_convs[os_key](prev_x) + fpn_out_dict[os_key] = prev_x + + remaining_output_strides = self.output_strides[:-1] + + # bottom-up processing + for os in remaining_output_strides[::-1]: + os_key = "os_{}".format(os) + # 1x1 conv + curr_x = self.proj_layers[os_key](x[os_key]) + # upsample + prev_x = F.interpolate(prev_x, size=curr_x.shape[-2:], mode="nearest") + # add + prev_x = curr_x + prev_x + prev_x = self.nxn_convs[os_key](prev_x) + fpn_out_dict[os_key] = prev_x + + return fpn_out_dict + + def profile_module( + self, input: Dict[str, Tensor], *args, **kwargs + ) -> (Dict[str, Tensor], float, float): + params, macs = 0.0, 0.0 + + # dictionary to store results for fpn + fpn_out_dict = {"os_{}".format(os): None for os in self.output_strides} + + # process the last output stride + os_key = "os_{}".format(self.output_strides[-1]) + prev_x, p, m = module_profile(module=self.proj_layers[os_key], x=input[os_key]) + params += p + macs += m + + prev_x, p, m = module_profile(module=self.nxn_convs[os_key], x=prev_x) + params += p + macs += m + + fpn_out_dict[os_key] = prev_x + + remaining_output_strides = self.output_strides[:-1] + + for os in remaining_output_strides[::-1]: + # 1x1 conv + os_key = "os_{}".format(os) + curr_x, p, m = module_profile( + module=self.proj_layers[os_key], x=input[os_key] + ) + params += p + macs += m + + # upsample + prev_x = F.interpolate(prev_x, size=curr_x.shape[-2:], mode="nearest") + # add + prev_x = curr_x + prev_x + prev_x, p, m = module_profile(module=self.nxn_convs[os_key], x=prev_x) + params += p + macs += m + + fpn_out_dict[os_key] = prev_x + + return fpn_out_dict, params, macs + + def __repr__(self): + return "{}(in_channels={}, output_strides={} out_channels={})".format( + self.__class__.__name__, + self.in_channels, + self.output_strides, + self.out_channels, + ) diff --git a/Adaptive Frequency Filters/affnet/modules/mobilenetv2.py b/Adaptive Frequency Filters/affnet/modules/mobilenetv2.py new file mode 100644 index 0000000..4bcdc23 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/mobilenetv2.py @@ -0,0 +1,257 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Optional, Union, Tuple + +from utils.math_utils import make_divisible + +from . import BaseModule, SqueezeExcitation +from ..misc.profiler import module_profile +from ..layers import ConvLayer, get_activation_fn + + +class InvertedResidualSE(BaseModule): + """ + This class implements the inverted residual block with squeeze-excitation unit, as described in + `MobileNetv3 `_ paper + + Args: + opts: command-line arguments + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})` + out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out)` + expand_ratio (Union[int, float]): Expand the input channels by this factor in depth-wise conv + dilation (Optional[int]): Use conv with dilation. Default: 1 + stride (Optional[int]): Use convolutions with a stride. Default: 1 + use_se (Optional[bool]): Use squeeze-excitation block. Default: False + act_fn_name (Optional[str]): Activation function name. Default: relu + se_scale_fn_name (Optional [str]): Scale activation function inside SE unit. Defaults to hard_sigmoid + kernel_size (Optional[int]): Kernel size in depth-wise convolution. Defaults to 3. + squeeze_factor (Optional[bool]): Squeezing factor in SE unit. Defaults to 4. + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` + """ + + def __init__( + self, + opts, + in_channels: int, + out_channels: int, + expand_ratio: Union[int, float], + dilation: Optional[int] = 1, + stride: Optional[int] = 1, + use_se: Optional[bool] = False, + act_fn_name: Optional[str] = "relu", + se_scale_fn_name: Optional[str] = "hard_sigmoid", + kernel_size: Optional[int] = 3, + squeeze_factor: Optional[int] = 4, + *args, + **kwargs + ) -> None: + hidden_dim = make_divisible(int(round(in_channels * expand_ratio)), 8) + act_fn = get_activation_fn(act_type=act_fn_name, inplace=True) + + super().__init__() + + block = nn.Sequential() + if expand_ratio != 1: + block.add_module( + name="exp_1x1", + module=ConvLayer( + opts, + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + use_act=False, + use_norm=True, + ), + ) + block.add_module(name="act_fn_1", module=act_fn) + + block.add_module( + name="conv_3x3", + module=ConvLayer( + opts, + in_channels=hidden_dim, + out_channels=hidden_dim, + stride=stride, + kernel_size=kernel_size, + groups=hidden_dim, + use_act=False, + use_norm=True, + dilation=dilation, + ), + ) + block.add_module(name="act_fn_2", module=act_fn) + + if use_se: + se = SqueezeExcitation( + opts=opts, + in_channels=hidden_dim, + squeeze_factor=squeeze_factor, + scale_fn_name=se_scale_fn_name, + ) + block.add_module(name="se", module=se) + + block.add_module( + name="red_1x1", + module=ConvLayer( + opts, + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + use_act=False, + use_norm=True, + ), + ) + + self.block = block + self.in_channels = in_channels + self.out_channels = out_channels + self.exp = expand_ratio + self.dilation = dilation + self.use_se = use_se + self.stride = stride + self.act_fn_name = act_fn_name + self.kernel_size = kernel_size + self.use_res_connect = self.stride == 1 and in_channels == out_channels + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + y = self.block(x) + return x + y if self.use_res_connect else y + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + return module_profile(module=self.block, x=input) + + def __repr__(self) -> str: + return "{}(in_channels={}, out_channels={}, stride={}, exp={}, dilation={}, use_se={}, kernel_size={}, act_fn={})".format( + self.__class__.__name__, + self.in_channels, + self.out_channels, + self.stride, + self.exp, + self.dilation, + self.use_se, + self.kernel_size, + self.act_fn_name, + ) + + +class InvertedResidual(BaseModule): + """ + This class implements the inverted residual block, as described in `MobileNetv2 `_ paper + + Args: + opts: command-line arguments + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})` + out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out)` + stride (Optional[int]): Use convolutions with a stride. Default: 1 + expand_ratio (Union[int, float]): Expand the input channels by this factor in depth-wise conv + dilation (Optional[int]): Use conv with dilation. Default: 1 + skip_connection (Optional[bool]): Use skip-connection. Default: True + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` + + .. note:: + If `in_channels =! out_channels` and `stride > 1`, we set `skip_connection=False` + + """ + + def __init__( + self, + opts, + in_channels: int, + out_channels: int, + stride: int, + expand_ratio: Union[int, float], + dilation: int = 1, + skip_connection: Optional[bool] = True, + *args, + **kwargs + ) -> None: + assert stride in [1, 2] + hidden_dim = make_divisible(int(round(in_channels * expand_ratio)), 8) + + super().__init__() + + block = nn.Sequential() + if expand_ratio != 1: + block.add_module( + name="exp_1x1", + module=ConvLayer( + opts, + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + use_act=True, + use_norm=True, + ), + ) + + block.add_module( + name="conv_3x3", + module=ConvLayer( + opts, + in_channels=hidden_dim, + out_channels=hidden_dim, + stride=stride, + kernel_size=3, + groups=hidden_dim, + use_act=True, + use_norm=True, + dilation=dilation, + ), + ) + + block.add_module( + name="red_1x1", + module=ConvLayer( + opts, + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + use_act=False, + use_norm=True, + ), + ) + + self.block = block + self.in_channels = in_channels + self.out_channels = out_channels + self.exp = expand_ratio + self.dilation = dilation + self.stride = stride + self.use_res_connect = ( + self.stride == 1 and in_channels == out_channels and skip_connection + ) + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + if self.use_res_connect: + return x + self.block(x) + else: + return self.block(x) + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + return module_profile(module=self.block, x=input) + + def __repr__(self) -> str: + return "{}(in_channels={}, out_channels={}, stride={}, exp={}, dilation={}, skip_conn={})".format( + self.__class__.__name__, + self.in_channels, + self.out_channels, + self.stride, + self.exp, + self.dilation, + self.use_res_connect, + ) diff --git a/Adaptive Frequency Filters/affnet/modules/mobilevit_block.py b/Adaptive Frequency Filters/affnet/modules/mobilevit_block.py new file mode 100644 index 0000000..f555c9f --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/mobilevit_block.py @@ -0,0 +1,724 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# ----- + +import numpy as np +from torch import nn, Tensor +import math +import torch +from torch.nn import functional as F +from typing import Optional, Dict, Tuple, Union, Sequence + +from .transformer import TransformerEncoder, LinearAttnFFN +from .base_module import BaseModule +from ..misc.profiler import module_profile +from ..layers import ConvLayer, get_normalization_layer + + +class MobileViTBlock(BaseModule): + """ + This class defines the `MobileViT block `_ + + Args: + opts: command line arguments + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)` + transformer_dim (int): Input dimension to the transformer unit + ffn_dim (int): Dimension of the FFN block + n_transformer_blocks (Optional[int]): Number of transformer blocks. Default: 2 + head_dim (Optional[int]): Head dimension in the multi-head attention. Default: 32 + attn_dropout (Optional[float]): Dropout in multi-head attention. Default: 0.0 + dropout (Optional[float]): Dropout rate. Default: 0.0 + ffn_dropout (Optional[float]): Dropout between FFN layers in transformer. Default: 0.0 + patch_h (Optional[int]): Patch height for unfolding operation. Default: 8 + patch_w (Optional[int]): Patch width for unfolding operation. Default: 8 + transformer_norm_layer (Optional[str]): Normalization layer in the transformer block. Default: layer_norm + conv_ksize (Optional[int]): Kernel size to learn local representations in MobileViT block. Default: 3 + dilation (Optional[int]): Dilation rate in convolutions. Default: 1 + no_fusion (Optional[bool]): Do not combine the input and output feature maps. Default: False + """ + + def __init__( + self, + opts, + in_channels: int, + transformer_dim: int, + ffn_dim: int, + n_transformer_blocks: Optional[int] = 2, + head_dim: Optional[int] = 32, + attn_dropout: Optional[float] = 0.0, + dropout: Optional[int] = 0.0, + ffn_dropout: Optional[int] = 0.0, + patch_h: Optional[int] = 8, + patch_w: Optional[int] = 8, + transformer_norm_layer: Optional[str] = "layer_norm", + conv_ksize: Optional[int] = 3, + dilation: Optional[int] = 1, + no_fusion: Optional[bool] = False, + *args, + **kwargs + ) -> None: + conv_3x3_in = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=in_channels, + kernel_size=conv_ksize, + stride=1, + use_norm=True, + use_act=True, + dilation=dilation, + ) + conv_1x1_in = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=transformer_dim, + kernel_size=1, + stride=1, + use_norm=False, + use_act=False, + ) + + conv_1x1_out = ConvLayer( + opts=opts, + in_channels=transformer_dim, + out_channels=in_channels, + kernel_size=1, + stride=1, + use_norm=True, + use_act=True, + ) + conv_3x3_out = None + if not no_fusion: + conv_3x3_out = ConvLayer( + opts=opts, + in_channels=2 * in_channels, + out_channels=in_channels, + kernel_size=conv_ksize, + stride=1, + use_norm=True, + use_act=True, + ) + super().__init__() + self.local_rep = nn.Sequential() + self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in) + self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in) + + assert transformer_dim % head_dim == 0 + num_heads = transformer_dim // head_dim + + global_rep = [ + TransformerEncoder( + opts=opts, + embed_dim=transformer_dim, + ffn_latent_dim=ffn_dim, + num_heads=num_heads, + attn_dropout=attn_dropout, + dropout=dropout, + ffn_dropout=ffn_dropout, + transformer_norm_layer=transformer_norm_layer, + ) + for _ in range(n_transformer_blocks) + ] + global_rep.append( + get_normalization_layer( + opts=opts, + norm_type=transformer_norm_layer, + num_features=transformer_dim, + ) + ) + self.global_rep = nn.Sequential(*global_rep) + + self.conv_proj = conv_1x1_out + + self.fusion = conv_3x3_out + + self.patch_h = patch_h + self.patch_w = patch_w + self.patch_area = self.patch_w * self.patch_h + + self.cnn_in_dim = in_channels + self.cnn_out_dim = transformer_dim + self.n_heads = num_heads + self.ffn_dim = ffn_dim + self.dropout = dropout + self.attn_dropout = attn_dropout + self.ffn_dropout = ffn_dropout + self.dilation = dilation + self.n_blocks = n_transformer_blocks + self.conv_ksize = conv_ksize + + def __repr__(self) -> str: + repr_str = "{}(".format(self.__class__.__name__) + + repr_str += "\n\t Local representations" + if isinstance(self.local_rep, nn.Sequential): + for m in self.local_rep: + repr_str += "\n\t\t {}".format(m) + else: + repr_str += "\n\t\t {}".format(self.local_rep) + + repr_str += "\n\t Global representations with patch size of {}x{}".format( + self.patch_h, self.patch_w + ) + if isinstance(self.global_rep, nn.Sequential): + for m in self.global_rep: + repr_str += "\n\t\t {}".format(m) + else: + repr_str += "\n\t\t {}".format(self.global_rep) + + if isinstance(self.conv_proj, nn.Sequential): + for m in self.conv_proj: + repr_str += "\n\t\t {}".format(m) + else: + repr_str += "\n\t\t {}".format(self.conv_proj) + + if self.fusion is not None: + repr_str += "\n\t Feature fusion" + if isinstance(self.fusion, nn.Sequential): + for m in self.fusion: + repr_str += "\n\t\t {}".format(m) + else: + repr_str += "\n\t\t {}".format(self.fusion) + + repr_str += "\n)" + return repr_str + + def unfolding(self, feature_map: Tensor) -> Tuple[Tensor, Dict]: + patch_w, patch_h = self.patch_w, self.patch_h + patch_area = int(patch_w * patch_h) + batch_size, in_channels, orig_h, orig_w = feature_map.shape + + new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h) + new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w) + + interpolate = False + if new_w != orig_w or new_h != orig_h: + # Note: Padding can be done, but then it needs to be handled in attention function. + feature_map = F.interpolate( + feature_map, size=(new_h, new_w), mode="bilinear", align_corners=False + ) + interpolate = True + + # number of patches along width and height + num_patch_w = new_w // patch_w # n_w + num_patch_h = new_h // patch_h # n_h + num_patches = num_patch_h * num_patch_w # N + + # [B, C, H, W] --> [B * C * n_h, p_h, n_w, p_w] + reshaped_fm = feature_map.reshape( + batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w + ) + # [B * C * n_h, p_h, n_w, p_w] --> [B * C * n_h, n_w, p_h, p_w] + transposed_fm = reshaped_fm.transpose(1, 2) + # [B * C * n_h, n_w, p_h, p_w] --> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w + reshaped_fm = transposed_fm.reshape( + batch_size, in_channels, num_patches, patch_area + ) + # [B, C, N, P] --> [B, P, N, C] + transposed_fm = reshaped_fm.transpose(1, 3) + # [B, P, N, C] --> [BP, N, C] + patches = transposed_fm.reshape(batch_size * patch_area, num_patches, -1) + + info_dict = { + "orig_size": (orig_h, orig_w), + "batch_size": batch_size, + "interpolate": interpolate, + "total_patches": num_patches, + "num_patches_w": num_patch_w, + "num_patches_h": num_patch_h, + } + + return patches, info_dict + + def folding(self, patches: Tensor, info_dict: Dict) -> Tensor: + n_dim = patches.dim() + assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format( + patches.shape + ) + # [BP, N, C] --> [B, P, N, C] + patches = patches.contiguous().view( + info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1 + ) + + batch_size, pixels, num_patches, channels = patches.size() + num_patch_h = info_dict["num_patches_h"] + num_patch_w = info_dict["num_patches_w"] + + # [B, P, N, C] --> [B, C, N, P] + patches = patches.transpose(1, 3) + + # [B, C, N, P] --> [B*C*n_h, n_w, p_h, p_w] + feature_map = patches.reshape( + batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w + ) + # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] + feature_map = feature_map.transpose(1, 2) + # [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W] + feature_map = feature_map.reshape( + batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w + ) + if info_dict["interpolate"]: + feature_map = F.interpolate( + feature_map, + size=info_dict["orig_size"], + mode="bilinear", + align_corners=False, + ) + return feature_map + + def forward_spatial(self, x: Tensor) -> Tensor: + res = x + + fm = self.local_rep(x) + + # convert feature map to patches + patches, info_dict = self.unfolding(fm) + + # learn global representations + for transformer_layer in self.global_rep: + patches = transformer_layer(patches) + + # [B x Patch x Patches x C] --> [B x C x Patches x Patch] + fm = self.folding(patches=patches, info_dict=info_dict) + + fm = self.conv_proj(fm) + + if self.fusion is not None: + fm = self.fusion(torch.cat((res, fm), dim=1)) + return fm + + def forward_temporal( + self, x: Tensor, x_prev: Optional[Tensor] = None + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + + res = x + fm = self.local_rep(x) + + # convert feature map to patches + patches, info_dict = self.unfolding(fm) + + # learn global representations + for global_layer in self.global_rep: + if isinstance(global_layer, TransformerEncoder): + patches = global_layer(x=patches, x_prev=x_prev) + else: + patches = global_layer(patches) + + # [B x Patch x Patches x C] --> [B x C x Patches x Patch] + fm = self.folding(patches=patches, info_dict=info_dict) + + fm = self.conv_proj(fm) + + if self.fusion is not None: + fm = self.fusion(torch.cat((res, fm), dim=1)) + return fm, patches + + def forward( + self, x: Union[Tensor, Tuple[Tensor]], *args, **kwargs + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + if isinstance(x, Tuple) and len(x) == 2: + # for spatio-temporal MobileViT + return self.forward_temporal(x=x[0], x_prev=x[1]) + elif isinstance(x, Tensor): + # For image data + return self.forward_spatial(x) + else: + raise NotImplementedError + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + params = macs = 0.0 + + res = input + out, p, m = module_profile(module=self.local_rep, x=input) + params += p + macs += m + + patches, info_dict = self.unfolding(feature_map=out) + + patches, p, m = module_profile(module=self.global_rep, x=patches) + params += p + macs += m + + fm = self.folding(patches=patches, info_dict=info_dict) + + out, p, m = module_profile(module=self.conv_proj, x=fm) + params += p + macs += m + + if self.fusion is not None: + out, p, m = module_profile( + module=self.fusion, x=torch.cat((out, res), dim=1) + ) + params += p + macs += m + + return res, params, macs + + +class MobileViTBlockv2(BaseModule): + """ + This class defines the `MobileViTv2 `_ block + + Args: + opts: command line arguments + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)` + attn_unit_dim (int): Input dimension to the attention unit + ffn_multiplier (int): Expand the input dimensions by this factor in FFN. Default is 2. + n_attn_blocks (Optional[int]): Number of attention units. Default: 2 + attn_dropout (Optional[float]): Dropout in multi-head attention. Default: 0.0 + dropout (Optional[float]): Dropout rate. Default: 0.0 + ffn_dropout (Optional[float]): Dropout between FFN layers in transformer. Default: 0.0 + patch_h (Optional[int]): Patch height for unfolding operation. Default: 8 + patch_w (Optional[int]): Patch width for unfolding operation. Default: 8 + conv_ksize (Optional[int]): Kernel size to learn local representations in MobileViT block. Default: 3 + dilation (Optional[int]): Dilation rate in convolutions. Default: 1 + attn_norm_layer (Optional[str]): Normalization layer in the attention block. Default: layer_norm_2d + """ + + def __init__( + self, + opts, + in_channels: int, + attn_unit_dim: int, + ffn_multiplier: Optional[Union[Sequence[Union[int, float]], int, float]] = 2.0, + n_attn_blocks: Optional[int] = 2, + attn_dropout: Optional[float] = 0.0, + dropout: Optional[float] = 0.0, + ffn_dropout: Optional[float] = 0.0, + patch_h: Optional[int] = 8, + patch_w: Optional[int] = 8, + conv_ksize: Optional[int] = 3, + dilation: Optional[int] = 1, + attn_norm_layer: Optional[str] = "layer_norm_2d", + *args, + **kwargs + ) -> None: + cnn_out_dim = attn_unit_dim + + conv_3x3_in = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=in_channels, + kernel_size=conv_ksize, + stride=1, + use_norm=True, + use_act=True, + dilation=dilation, + groups=in_channels, + ) + conv_1x1_in = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=cnn_out_dim, + kernel_size=1, + stride=1, + use_norm=False, + use_act=False, + ) + + super(MobileViTBlockv2, self).__init__() + self.local_rep = nn.Sequential(conv_3x3_in, conv_1x1_in) + + self.global_rep, attn_unit_dim = self._build_attn_layer( + opts=opts, + d_model=attn_unit_dim, + ffn_mult=ffn_multiplier, + n_layers=n_attn_blocks, + attn_dropout=attn_dropout, + dropout=dropout, + ffn_dropout=ffn_dropout, + attn_norm_layer=attn_norm_layer, + ) + + self.conv_proj = ConvLayer( + opts=opts, + in_channels=cnn_out_dim, + out_channels=in_channels, + kernel_size=1, + stride=1, + use_norm=True, + use_act=False, + ) + + self.patch_h = patch_h + self.patch_w = patch_w + self.patch_area = self.patch_w * self.patch_h + + self.cnn_in_dim = in_channels + self.cnn_out_dim = cnn_out_dim + self.transformer_in_dim = attn_unit_dim + self.dropout = dropout + self.attn_dropout = attn_dropout + self.ffn_dropout = ffn_dropout + self.n_blocks = n_attn_blocks + self.conv_ksize = conv_ksize + self.enable_coreml_compatible_fn = getattr( + opts, "common.enable_coreml_compatible_module", False + ) + + if self.enable_coreml_compatible_fn: + # we set persistent to false so that these weights are not part of model's state_dict + self.register_buffer( + name="unfolding_weights", + tensor=self._compute_unfolding_weights(), + persistent=False, + ) + + def _compute_unfolding_weights(self) -> Tensor: + # [P_h * P_w, P_h * P_w] + weights = torch.eye(self.patch_h * self.patch_w, dtype=torch.float) + # [P_h * P_w, P_h * P_w] --> [P_h * P_w, 1, P_h, P_w] + weights = weights.reshape( + (self.patch_h * self.patch_w, 1, self.patch_h, self.patch_w) + ) + # [P_h * P_w, 1, P_h, P_w] --> [P_h * P_w * C, 1, P_h, P_w] + weights = weights.repeat(self.cnn_out_dim, 1, 1, 1) + return weights + + def _build_attn_layer( + self, + opts, + d_model: int, + ffn_mult: Union[Sequence, int, float], + n_layers: int, + attn_dropout: float, + dropout: float, + ffn_dropout: float, + attn_norm_layer: str, + *args, + **kwargs + ) -> Tuple[nn.Module, int]: + + if isinstance(ffn_mult, Sequence) and len(ffn_mult) == 2: + ffn_dims = ( + np.linspace(ffn_mult[0], ffn_mult[1], n_layers, dtype=float) * d_model + ) + elif isinstance(ffn_mult, Sequence) and len(ffn_mult) == 1: + ffn_dims = [ffn_mult[0] * d_model] * n_layers + elif isinstance(ffn_mult, (int, float)): + ffn_dims = [ffn_mult * d_model] * n_layers + else: + raise NotImplementedError + + # ensure that dims are multiple of 16 + ffn_dims = [int((d // 16) * 16) for d in ffn_dims] + + global_rep = [ + LinearAttnFFN( + opts=opts, + embed_dim=d_model, + ffn_latent_dim=ffn_dims[block_idx], + attn_dropout=attn_dropout, + dropout=dropout, + ffn_dropout=ffn_dropout, + norm_layer=attn_norm_layer, + ) + for block_idx in range(n_layers) + ] + global_rep.append( + get_normalization_layer( + opts=opts, norm_type=attn_norm_layer, num_features=d_model + ) + ) + + return nn.Sequential(*global_rep), d_model + + def __repr__(self) -> str: + repr_str = "{}(".format(self.__class__.__name__) + + repr_str += "\n\t Local representations" + if isinstance(self.local_rep, nn.Sequential): + for m in self.local_rep: + repr_str += "\n\t\t {}".format(m) + else: + repr_str += "\n\t\t {}".format(self.local_rep) + + repr_str += "\n\t Global representations with patch size of {}x{}".format( + self.patch_h, + self.patch_w, + ) + if isinstance(self.global_rep, nn.Sequential): + for m in self.global_rep: + repr_str += "\n\t\t {}".format(m) + else: + repr_str += "\n\t\t {}".format(self.global_rep) + + if isinstance(self.conv_proj, nn.Sequential): + for m in self.conv_proj: + repr_str += "\n\t\t {}".format(m) + else: + repr_str += "\n\t\t {}".format(self.conv_proj) + + repr_str += "\n)" + return repr_str + + def unfolding_pytorch(self, feature_map: Tensor) -> Tuple[Tensor, Tuple[int, int]]: + + batch_size, in_channels, img_h, img_w = feature_map.shape + + # [B, C, H, W] --> [B, C, P, N] + patches = F.unfold( + feature_map, + kernel_size=(self.patch_h, self.patch_w), + stride=(self.patch_h, self.patch_w), + ) + patches = patches.reshape( + batch_size, in_channels, self.patch_h * self.patch_w, -1 + ) + + return patches, (img_h, img_w) + + def folding_pytorch(self, patches: Tensor, output_size: Tuple[int, int]) -> Tensor: + batch_size, in_dim, patch_size, n_patches = patches.shape + + # [B, C, P, N] + patches = patches.reshape(batch_size, in_dim * patch_size, n_patches) + + feature_map = F.fold( + patches, + output_size=output_size, + kernel_size=(self.patch_h, self.patch_w), + stride=(self.patch_h, self.patch_w), + ) + + return feature_map + + def unfolding_coreml(self, feature_map: Tensor) -> Tuple[Tensor, Tuple[int, int]]: + # im2col is not implemented in Coreml, so here we hack its implementation using conv2d + # we compute the weights + + # [B, C, H, W] --> [B, C, P, N] + batch_size, in_channels, img_h, img_w = feature_map.shape + # + patches = F.conv2d( + feature_map, + self.unfolding_weights, + bias=None, + stride=(self.patch_h, self.patch_w), + padding=0, + dilation=1, + groups=in_channels, + ) + patches = patches.reshape( + batch_size, in_channels, self.patch_h * self.patch_w, -1 + ) + return patches, (img_h, img_w) + + def folding_coreml(self, patches: Tensor, output_size: Tuple[int, int]) -> Tensor: + # col2im is not supported on coreml, so tracing fails + # We hack folding function via pixel_shuffle to enable coreml tracing + batch_size, in_dim, patch_size, n_patches = patches.shape + + n_patches_h = output_size[0] // self.patch_h + n_patches_w = output_size[1] // self.patch_w + + feature_map = patches.reshape( + batch_size, in_dim * self.patch_h * self.patch_w, n_patches_h, n_patches_w + ) + assert ( + self.patch_h == self.patch_w + ), "For Coreml, we need patch_h and patch_w are the same" + feature_map = F.pixel_shuffle(feature_map, upscale_factor=self.patch_h) + return feature_map + + def resize_input_if_needed(self, x): + batch_size, in_channels, orig_h, orig_w = x.shape + if orig_h % self.patch_h != 0 or orig_w % self.patch_w != 0: + new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h) + new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w) + x = F.interpolate( + x, size=(new_h, new_w), mode="bilinear", align_corners=True + ) + return x + + def forward_spatial(self, x: Tensor, *args, **kwargs) -> Tensor: + x = self.resize_input_if_needed(x) + + fm = self.local_rep(x) + + # convert feature map to patches + if self.enable_coreml_compatible_fn: + patches, output_size = self.unfolding_coreml(fm) + else: + patches, output_size = self.unfolding_pytorch(fm) + + # learn global representations on all patches + patches = self.global_rep(patches) + + # [B x Patch x Patches x C] --> [B x C x Patches x Patch] + if self.enable_coreml_compatible_fn: + fm = self.folding_coreml(patches=patches, output_size=output_size) + else: + fm = self.folding_pytorch(patches=patches, output_size=output_size) + fm = self.conv_proj(fm) + + return fm + + def forward_temporal( + self, x: Tensor, x_prev: Tensor, *args, **kwargs + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + x = self.resize_input_if_needed(x) + + fm = self.local_rep(x) + + # convert feature map to patches + if self.enable_coreml_compatible_fn: + patches, output_size = self.unfolding_coreml(fm) + else: + patches, output_size = self.unfolding_pytorch(fm) + + # learn global representations + for global_layer in self.global_rep: + if isinstance(global_layer, LinearAttnFFN): + patches = global_layer(x=patches, x_prev=x_prev) + else: + patches = global_layer(patches) + + # [B x Patch x Patches x C] --> [B x C x Patches x Patch] + if self.enable_coreml_compatible_fn: + fm = self.folding_coreml(patches=patches, output_size=output_size) + else: + fm = self.folding_pytorch(patches=patches, output_size=output_size) + fm = self.conv_proj(fm) + + return fm, patches + + def forward( + self, x: Union[Tensor, Tuple[Tensor]], *args, **kwargs + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + if isinstance(x, Tuple) and len(x) == 2: + # for spatio-temporal data (e.g., videos) + return self.forward_temporal(x=x[0], x_prev=x[1]) + elif isinstance(x, Tensor): + # for image data + return self.forward_spatial(x) + else: + raise NotImplementedError + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + params = macs = 0.0 + input = self.resize_input_if_needed(input) + + res = input + out, p, m = module_profile(module=self.local_rep, x=input) + params += p + macs += m + + patches, output_size = self.unfolding_pytorch(feature_map=out) + + patches, p, m = module_profile(module=self.global_rep, x=patches) + params += p + macs += m + + fm = self.folding_pytorch(patches=patches, output_size=output_size) + + out, p, m = module_profile(module=self.conv_proj, x=fm) + params += p + macs += m + + return res, params, macs diff --git a/Adaptive Frequency Filters/affnet/modules/pspnet_module.py b/Adaptive Frequency Filters/affnet/modules/pspnet_module.py new file mode 100644 index 0000000..141cc4a --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/pspnet_module.py @@ -0,0 +1,135 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from typing import Optional, Sequence, Tuple +import torch.nn.functional as F + +from utils import logger + +from ..layers import ConvLayer, AdaptiveAvgPool2d, Dropout2d +from ..modules import BaseModule +from ..misc.profiler import module_profile + + +class PSP(BaseModule): + """ + This class defines the Pyramid Scene Parsing module in the `PSPNet paper `_ + + Args: + opts: command-line arguments + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)` + out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H, W)` + pool_sizes Optional[Tuple[int, ...]]: List or Tuple of pool sizes. Default: (1, 2, 3, 6) + dropout (Optional[float]): Apply dropout. Default is 0.0 + """ + + def __init__( + self, + opts, + in_channels: int, + out_channels: int, + pool_sizes: Optional[Tuple[int, ...]] = (1, 2, 3, 6), + dropout: Optional[float] = 0.0, + *args, + **kwargs + ) -> None: + if not (0.0 <= dropout < 1.0): + logger.error( + "Dropout value in {} should be between 0 and 1. Got: {}".format( + self.__class__.__name__, dropout + ) + ) + reduction_dim = in_channels // len(pool_sizes) + reduction_dim = (reduction_dim // 16) * 16 + channels_after_concat = (reduction_dim * len(pool_sizes)) + in_channels + + super().__init__() + self.psp_branches = nn.ModuleList( + [ + self._make_psp_layer( + opts, o_size=ps, in_channels=in_channels, out_channels=reduction_dim + ) + for ps in pool_sizes + ] + ) + self.fusion = nn.Sequential( + ConvLayer( + opts=opts, + in_channels=channels_after_concat, + out_channels=out_channels, + kernel_size=3, + stride=1, + use_norm=True, + use_act=True, + ), + Dropout2d(p=dropout), + ) + + self.in_channels = in_channels + self.out_channels = out_channels + self.pool_sizes = pool_sizes + self.inner_channels = reduction_dim + self.dropout = dropout + + @staticmethod + def _make_psp_layer( + opts, o_size: int, in_channels: int, out_channels: int + ) -> nn.Module: + return nn.Sequential( + AdaptiveAvgPool2d(output_size=(o_size, o_size)), + ConvLayer( + opts, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + bias=False, + use_norm=True, + use_act=True, + ), + ) + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + x_size = x.shape[2:] + out = [x] + [ + F.interpolate( + input=psp_branch(x), size=x_size, mode="bilinear", align_corners=True + ) + for psp_branch in self.psp_branches + ] + out = torch.cat(out, dim=1) + out = self.fusion(out) + return out + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + params, macs = 0.0, 0.0 + res = [input] + input_size = input.size() + for psp_branch in self.psp_branches: + out, p, m = module_profile(module=psp_branch, x=input) + out = F.interpolate( + out, input_size[2:], mode="bilinear", align_corners=True + ) + params += p + macs += m + res.append(out) + res = torch.cat(res, dim=1) + + res, p, m = module_profile(module=self.fusion, x=res) + return res, params + p, macs + m + + def __repr__(self): + return "{}(in_channels={}, out_channels={}, pool_sizes={}, inner_channels={}, dropout_2d={})".format( + self.__class__.__name__, + self.in_channels, + self.out_channels, + self.pool_sizes, + self.inner_channels, + self.dropout, + ) diff --git a/Adaptive Frequency Filters/affnet/modules/resnet_modules.py b/Adaptive Frequency Filters/affnet/modules/resnet_modules.py new file mode 100644 index 0000000..f825278 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/resnet_modules.py @@ -0,0 +1,265 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Optional, Tuple + +from ..layers import ConvLayer, Identity, get_activation_fn, Dropout +from ..modules import BaseModule +from ..misc.profiler import module_profile + + +class BasicResNetBlock(BaseModule): + """ + This class defines the Basic block in the `ResNet model `_ + Args: + opts: command-line arguments + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})` + mid_channels (int): :math:`C_{mid}` from an expected tensor of size :math:`(N, C_{mid}, H_{out}, W_{out})` + out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})` + stride (Optional[int]): Stride for convolution. Default: 1 + dilation (Optional[int]): Dilation for convolution. Default: 1 + dropout (Optional[float]): Dropout after second convolution. Default: 0.0 + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` + + """ + + expansion: int = 1 + + def __init__( + self, + opts, + in_channels: int, + mid_channels: int, + out_channels: int, + stride: Optional[int] = 1, + dilation: Optional[int] = 1, + dropout: Optional[float] = 0.0, + *args, + **kwargs + ) -> None: + + act_type = getattr(opts, "model.activation.name", "relu") + neg_slope = getattr(opts, "model.activation.neg_slope", 0.1) + inplace = getattr(opts, "model.activation.inplace", False) + + cbr_1 = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=3, + stride=stride, + dilation=dilation, + use_norm=True, + use_act=True, + ) + cb_2 = ConvLayer( + opts=opts, + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + use_norm=True, + use_act=False, + dilation=dilation, + ) + + block = nn.Sequential() + block.add_module(name="conv_batch_act_1", module=cbr_1) + block.add_module(name="conv_batch_2", module=cb_2) + if 0.0 < dropout < 1.0: + block.add_module(name="dropout", module=Dropout(p=dropout)) + + down_sample = Identity() + if stride == 2: + down_sample = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=stride, + use_norm=True, + use_act=False, + ) + + super().__init__() + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.block = block + self.down_sample = down_sample + + self.final_act = get_activation_fn( + act_type=act_type, + inplace=inplace, + negative_slope=neg_slope, + num_parameters=out_channels, + ) + self.stride = stride + self.in_channels = in_channels + self.out_channels = out_channels + self.dilation = dilation + self.dropout = dropout + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + out = self.block(x) + res = self.down_sample(x) + out = out + res + return self.final_act(out) + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + out, n_params, n_macs = module_profile(module=self.block, x=input) + _, n_params_down, n_macs_down = module_profile(module=self.down_sample, x=input) + return out, n_params + n_params_down, n_macs + n_macs_down + + def __repr__(self) -> str: + return "{}(in_channels={}, out_channels={}, stride={}, dilation={}, dropout={})".format( + self.__class__.__name__, + self.in_channels, + self.out_channels, + self.stride, + self.dilation, + self.dropout, + ) + + +class BottleneckResNetBlock(BaseModule): + """ + This class defines the Bottleneck block in the `ResNet model `_ + Args: + opts: command-line arguments + in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})` + mid_channels (int): :math:`C_{mid}` from an expected tensor of size :math:`(N, C_{mid}, H_{out}, W_{out})` + out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})` + stride (Optional[int]): Stride for convolution. Default: 1 + dilation (Optional[int]): Dilation for convolution. Default: 1 + dropout (Optional[float]): Dropout after third convolution. Default: 0.0 + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` + + """ + + expansion: int = 4 + + def __init__( + self, + opts, + in_channels: int, + mid_channels: int, + out_channels: int, + stride: Optional[int] = 1, + dilation: Optional[int] = 1, + dropout: Optional[float] = 0.0, + *args, + **kwargs + ) -> None: + act_type = getattr(opts, "model.activation.name", "relu") + neg_slope = getattr(opts, "model.activation.neg_slope", 0.1) + inplace = getattr(opts, "model.activation.inplace", False) + + cbr_1 = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + use_norm=True, + use_act=True, + ) + cbr_2 = ConvLayer( + opts=opts, + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=3, + stride=stride, + use_norm=True, + use_act=True, + dilation=dilation, + ) + cb_3 = ConvLayer( + opts=opts, + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + use_norm=True, + use_act=False, + ) + block = nn.Sequential() + block.add_module(name="conv_batch_act_1", module=cbr_1) + block.add_module(name="conv_batch_act_2", module=cbr_2) + block.add_module(name="conv_batch_3", module=cb_3) + if 0.0 < dropout < 1.0: + block.add_module(name="dropout", module=Dropout(p=dropout)) + + down_sample = Identity() + if stride == 2: + down_sample = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=stride, + use_norm=True, + use_act=False, + ) + elif in_channels != out_channels: + down_sample = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + use_norm=True, + use_act=False, + ) + + super().__init__() + self.block = block + + self.down_sample = down_sample + self.final_act = get_activation_fn( + act_type=act_type, + inplace=inplace, + negative_slope=neg_slope, + num_parameters=out_channels, + ) + + self.stride = stride + self.in_channels = in_channels + self.out_channels = out_channels + self.mid_channels = mid_channels + self.dilation = dilation + self.dropout = dropout + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + out = self.block(x) + res = self.down_sample(x) + out = out + res + return self.final_act(out) + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + out, n_params, n_macs = module_profile(module=self.block, x=input) + _, n_params_down, n_macs_down = module_profile(module=self.down_sample, x=input) + return out, n_params + n_params_down, n_macs + n_macs_down + + def __repr__(self) -> str: + return "{}(in_channels={}, mid_channels={}, out_channels={}, stride={}, dilation={}, dropout={})".format( + self.__class__.__name__, + self.in_channels, + self.mid_channels, + self.out_channels, + self.stride, + self.dilation, + self.dropout, + ) diff --git a/Adaptive Frequency Filters/affnet/modules/squeeze_excitation.py b/Adaptive Frequency Filters/affnet/modules/squeeze_excitation.py new file mode 100644 index 0000000..1e0dc1c --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/squeeze_excitation.py @@ -0,0 +1,90 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from typing import Optional +from utils.math_utils import make_divisible + +from ..layers import AdaptiveAvgPool2d, ConvLayer, get_activation_fn +from ..modules import BaseModule +from ..misc.profiler import module_profile + + +class SqueezeExcitation(BaseModule): + """ + This class defines the Squeeze-excitation module, in the `SENet paper `_ + + Args: + opts: command-line arguments + in_channels (int): :math:`C` from an expected input of size :math:`(N, C, H, W)` + squeeze_factor (Optional[int]): Reduce :math:`C` by this factor. Default: 4 + scale_fn_name (Optional[str]): Scaling function name. Default: sigmoid + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` + """ + + def __init__( + self, + opts, + in_channels: int, + squeeze_factor: Optional[int] = 4, + scale_fn_name: Optional[str] = "sigmoid", + *args, + **kwargs + ) -> None: + squeeze_channels = max(make_divisible(in_channels // squeeze_factor, 8), 32) + + fc1 = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=squeeze_channels, + kernel_size=1, + stride=1, + bias=True, + use_norm=False, + use_act=True, + ) + fc2 = ConvLayer( + opts=opts, + in_channels=squeeze_channels, + out_channels=in_channels, + kernel_size=1, + stride=1, + bias=True, + use_norm=False, + use_act=False, + ) + act_fn = get_activation_fn(act_type=scale_fn_name, inplace=True) + super().__init__() + self.se_layer = nn.Sequential() + self.se_layer.add_module( + name="global_pool", module=AdaptiveAvgPool2d(output_size=1) + ) + self.se_layer.add_module(name="fc1", module=fc1) + self.se_layer.add_module(name="fc2", module=fc2) + self.se_layer.add_module(name="scale_act", module=act_fn) + + self.in_channels = in_channels + self.squeeze_factor = squeeze_factor + self.scale_fn = scale_fn_name + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + return x * self.se_layer(x) + + def profile_module(self, input: Tensor, *args, **kwargs) -> (Tensor, float, float): + _, params, macs = module_profile(module=self.se_layer, x=input) + return input, params, macs + + def __repr__(self) -> str: + return "{}(in_channels={}, squeeze_factor={}, scale_fn={})".format( + self.__class__.__name__, + self.in_channels, + self.squeeze_factor, + self.scale_fn, + ) diff --git a/Adaptive Frequency Filters/affnet/modules/ssd_heads.py b/Adaptive Frequency Filters/affnet/modules/ssd_heads.py new file mode 100644 index 0000000..6f19284 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/ssd_heads.py @@ -0,0 +1,263 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +from typing import Optional, Tuple +from torchvision.ops.roi_align import RoIAlign + +from ..layers import ConvLayer, SeparableConv, TransposeConvLayer +from ..modules import BaseModule +from ..misc.profiler import module_profile +from ..misc.init_utils import initialize_conv_layer + + +class SSDHead(BaseModule): + """ + This class defines the `SSD object detection Head `_ + + Args: + opts: command-line arguments + in_channels (int): :math:`C` from an expected input of size :math:`(N, C, H, W)` + n_anchors (int): Number of anchors + n_classes (int): Number of classes in the dataset + n_coordinates (Optional[int]): Number of coordinates. Default: 4 (x, y, w, h) + proj_channels (Optional[int]): Number of projected channels. If `-1`, then projection layer is not used + kernel_size (Optional[int]): Kernel size in convolutional layer. If kernel_size=1, then standard + point-wise convolution is used. Otherwise, separable convolution is used + stride (Optional[int]): stride for feature map. If stride > 1, then feature map is sampled at this rate + and predictions are made on fewer pixels as compared to the input tensor. Default: 1 + """ + + def __init__( + self, + opts, + in_channels: int, + n_anchors: int, + n_classes: int, + n_coordinates: Optional[int] = 4, + proj_channels: Optional[int] = -1, + kernel_size: Optional[int] = 3, + stride: Optional[int] = 1, + *args, + **kwargs + ) -> None: + super().__init__() + proj_layer = None + self.proj_channels = None + if proj_channels != -1 and proj_channels != in_channels and kernel_size > 1: + proj_layer = ConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=proj_channels, + kernel_size=1, + stride=1, + groups=1, + bias=False, + use_norm=True, + use_act=True, + ) + in_channels = proj_channels + self.proj_channels = proj_channels + + self.proj_layer = proj_layer + + conv_fn = ConvLayer if kernel_size == 1 else SeparableConv + if kernel_size > 1 and stride > 1: + kernel_size = max(kernel_size, stride if stride % 2 != 0 else stride + 1) + self.loc_cls_layer = conv_fn( + opts=opts, + in_channels=in_channels, + out_channels=n_anchors * (n_coordinates + n_classes), + kernel_size=kernel_size, + stride=1, + groups=1, + bias=True, + use_norm=False, + use_act=False, + ) + + self.n_coordinates = n_coordinates + self.n_classes = n_classes + self.n_anchors = n_anchors + self.k_size = kernel_size + self.stride = stride + self.in_channel = in_channels + + self.reset_parameters() + + def __repr__(self) -> str: + repr_str = "{}(in_channels={}, n_anchors={}, n_classes={}, n_coordinates={}, kernel_size={}, stride={}".format( + self.__class__.__name__, + self.in_channel, + self.n_anchors, + self.n_classes, + self.n_coordinates, + self.k_size, + self.stride, + ) + if self.proj_layer is not None: + repr_str += ", proj=True, proj_channels={}".format(self.proj_channels) + + repr_str += ")" + return repr_str + + def reset_parameters(self) -> None: + for layer in self.modules(): + if isinstance(layer, nn.Conv2d): + initialize_conv_layer(module=layer, init_method="xavier_uniform") + + def _sample_fm(self, x: Tensor) -> Tensor: + height, width = x.shape[-2:] + device = x.device + start_step = max(0, self.stride // 2) + indices_h = torch.arange( + start=start_step, + end=height, + step=self.stride, + dtype=torch.int64, + device=device, + ) + indices_w = torch.arange( + start=start_step, + end=width, + step=self.stride, + dtype=torch.int64, + device=device, + ) + + x_sampled = torch.index_select(x, dim=-1, index=indices_w) + x_sampled = torch.index_select(x_sampled, dim=-2, index=indices_h) + return x_sampled + + def forward(self, x: Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]: + batch_size = x.shape[0] + + if self.proj_layer is not None: + x = self.proj_layer(x) + + # [B x C x H x W] --> [B x Anchors * (coordinates + classes) x H x W] + x = self.loc_cls_layer(x) + + if self.stride > 1: + x = self._sample_fm(x) + + # [B x Anchors * (coordinates + classes) x H x W] --> [B x H x W x Anchors * (coordinates + classes)] + x = x.permute(0, 2, 3, 1) + # [B x H x W x Anchors * (coordinates + classes)] --> [B x H*W*Anchors X (coordinates + classes)] + x = x.contiguous().view(batch_size, -1, self.n_coordinates + self.n_classes) + + # [B x H*W*Anchors X (coordinates + classes)] --> [B x H*W*Anchors X coordinates], [B x H*W*Anchors X classes] + box_locations, box_classes = torch.split( + x, [self.n_coordinates, self.n_classes], dim=-1 + ) + return box_locations, box_classes + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + params = macs = 0.0 + + if self.proj_layer is not None: + input, p, m = module_profile(module=self.proj_layer, x=input) + params += p + macs += m + + x, p, m = module_profile(module=self.loc_cls_layer, x=input) + params += p + macs += m + + return input, params, macs + + +class SSDInstanceHead(BaseModule): + """ + Instance segmentation head for SSD model. + """ + + def __init__( + self, + opts, + in_channels: int, + n_classes: Optional[int] = 1, + inner_dim: Optional[int] = 256, + output_stride: Optional[int] = 1, + output_size: Optional[int] = 8, + *args, + **kwargs + ) -> None: + """ + + Args: + opts: command-line arguments + in_channels (int): :math:`C` from an expected input of size :math:`(N, C, H, W)` + n_classes (Optional[int]): Number of classes. Default: 1 + inner_dim: (Optional[int]): Inner dimension of the instance head. Default: 256 + output_stride (Optional[int]): Output stride of the feature map. Output stride is the ratio of input to + the feature map size. Default: 1 + output_size (Optional[int]): Output size of the instances extracted from RoIAlign layer. Default: 8 + """ + super().__init__() + self.roi_align = RoIAlign( + output_size=output_size, + spatial_scale=1.0 / output_stride, + sampling_ratio=2, + aligned=True, + ) + + self.seg_head = nn.Sequential( + TransposeConvLayer( + opts=opts, + in_channels=in_channels, + out_channels=inner_dim, + kernel_size=2, + stride=2, + bias=True, + use_norm=False, + use_act=True, + auto_padding=False, + padding=0, + output_padding=0, + ), + ConvLayer( + opts=opts, + in_channels=inner_dim, + out_channels=n_classes, + kernel_size=1, + stride=1, + use_norm=False, + use_act=False, + bias=True, + ), + ) + self.inner_channels = inner_dim + self.in_channels = in_channels + self.mask_classes = n_classes + self.reset_parameters() + + def __repr__(self) -> str: + return "{}(in_channels={}, up_out_channels={}, n_classes={})".format( + self.__class__.__name__, + self.in_channels, + self.inner_channels, + self.mask_classes, + ) + + def reset_parameters(self) -> None: + for layer in self.modules(): + if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)): + initialize_conv_layer(module=layer, init_method="kaiming_normal") + + def forward(self, x: Tensor, boxes: Tensor, *args, **kwargs) -> Tensor: + rois = self.roi_align(x, boxes) + rois = self.seg_head(rois) + return rois + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + input, params, macs = module_profile(module=self.seg_head, x=input) + return input, params, macs diff --git a/Adaptive Frequency Filters/affnet/modules/swin_transformer_block.py b/Adaptive Frequency Filters/affnet/modules/swin_transformer_block.py new file mode 100644 index 0000000..9e7329f --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/swin_transformer_block.py @@ -0,0 +1,429 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +import torch +from torch.nn import functional as F +from typing import List, Optional, Tuple + +from ..layers import ( + get_normalization_layer, + LinearLayer, + get_activation_fn, + Dropout, + StochasticDepth, +) +from ..modules import BaseModule + + +""" +Most of the functions and classes below are heavily borrowed from torchvision https://github.com/pytorch/vision +""" + + +def _patch_merging_pad(x): + H, W, _ = x.shape[-3:] + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + return x + + +class Permute(BaseModule): + """This module returns a view of the tensor input with its dimensions permuted. + Args: + dims (List[int]): The desired ordering of dimensions + """ + + def __init__(self, dims: List[int]): + super().__init__() + self.dims = dims + + def forward(self, x: Tensor) -> Tensor: + return torch.permute(x, self.dims) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}(dims={self.dims})" + return s + + +class PatchMerging(BaseModule): + """Patch Merging Layer. + Args: + dim (int): Number of input channels. + norm_layer (str): Normalization layer name. + strided (Optional[bool]): Down-sample the input by a factor of 2. Default is True. + """ + + def __init__(self, opts, dim: int, norm_layer: str, strided: Optional[bool] = True): + super().__init__() + self.dim = dim + self.reduction = LinearLayer( + in_features=4 * dim, out_features=2 * dim, bias=False + ) + self.norm = get_normalization_layer( + opts=opts, norm_type=norm_layer, num_features=4 * dim + ) + self.strided = strided + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + """ + Args: + x (Tensor): input tensor with expected layout of [..., H, W, C] + Returns: + Tensor with layout of [..., H/2, W/2, 2*C] + """ + x = _patch_merging_pad(x) + + if self.strided: + x0 = x[..., 0::2, 0::2, :] # ... H/s W/s C + x1 = x[..., 1::2, 0::2, :] # ... H/s W/s C + x2 = x[..., 0::2, 1::2, :] # ... H/s W/s C + x3 = x[..., 1::2, 1::2, :] # ... H/s W/s C + x = torch.cat([x0, x1, x2, x3], -1) # ... H/s W/s 4*C + else: + x = torch.cat([x, x, x, x], -1) # H W 4*C + + x = self.norm(x) + x = self.reduction(x) # ... H/2 W/2 2*C + return x + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}(dim={self.dim})" + return s + + +def shifted_window_attention( + input: Tensor, + qkv_weight: Tensor, + proj_weight: Tensor, + relative_position_bias: Tensor, + window_size: List[int], + num_heads: int, + shift_size: List[int], + attention_dropout: float = 0.0, + dropout: float = 0.0, + qkv_bias: Optional[Tensor] = None, + proj_bias: Optional[Tensor] = None, +): + """ + Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + input (Tensor[N, H, W, C]): The input tensor or 4-dimensions. + qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. + proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection. + relative_position_bias (Tensor): The learned relative position bias added to attention. + window_size (List[int]): Window size. + num_heads (int): Number of attention heads. + shift_size (List[int]): Shift size for shifted window attention. + attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. + dropout (float): Dropout ratio of output. Default: 0.0. + qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. + proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. + Returns: + Tensor[N, H, W, C]: The output tensor after shifted window attention. + """ + B, H, W, C = input.shape + # pad feature maps to multiples of window size + pad_r = (window_size[1] - W % window_size[1]) % window_size[1] + pad_b = (window_size[0] - H % window_size[0]) % window_size[0] + x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) + _, pad_H, pad_W, _ = x.shape + + shift_size = shift_size.copy() + # If window size is larger than feature size, there is no need to shift window + if window_size[0] >= pad_H: + shift_size[0] = 0 + if window_size[1] >= pad_W: + shift_size[1] = 0 + + # cyclic shift + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) + + # partition windows + num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1]) + x = x.view( + B, + pad_H // window_size[0], + window_size[0], + pad_W // window_size[1], + window_size[1], + C, + ) + x = x.permute(0, 1, 3, 2, 4, 5).reshape( + B * num_windows, window_size[0] * window_size[1], C + ) # B*nW, Ws*Ws, C + + # multi-head attention + qkv = F.linear(x, qkv_weight, qkv_bias) + qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute( + 2, 0, 3, 1, 4 + ) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * (C // num_heads) ** -0.5 + attn = q.matmul(k.transpose(-2, -1)) + # add relative position bias + attn = attn + relative_position_bias + + if sum(shift_size) > 0: + # generate attention mask + attn_mask = x.new_zeros((pad_H, pad_W)) + h_slices = ( + (0, -window_size[0]), + (-window_size[0], -shift_size[0]), + (-shift_size[0], None), + ) + w_slices = ( + (0, -window_size[1]), + (-window_size[1], -shift_size[1]), + (-shift_size[1], None), + ) + count = 0 + for h in h_slices: + for w in w_slices: + attn_mask[h[0] : h[1], w[0] : w[1]] = count + count += 1 + attn_mask = attn_mask.view( + pad_H // window_size[0], + window_size[0], + pad_W // window_size[1], + window_size[1], + ) + attn_mask = attn_mask.permute(0, 2, 1, 3).reshape( + num_windows, window_size[0] * window_size[1] + ) + attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) + attn = attn.view( + x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1) + ) + attn = attn + attn_mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, num_heads, x.size(1), x.size(1)) + + attn = F.softmax(attn, dim=-1) + attn = F.dropout(attn, p=attention_dropout) + + x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C) + x = F.linear(x, proj_weight, proj_bias) + x = F.dropout(x, p=dropout) + + # reverse windows + x = x.view( + B, + pad_H // window_size[0], + pad_W // window_size[1], + window_size[0], + window_size[1], + C, + ) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C) + + # reverse cyclic shift + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) + + # unpad features + x = x[:, :H, :W, :].contiguous() + return x + + +class ShiftedWindowAttention(BaseModule): + """ + See :func:`shifted_window_attention`. + """ + + def __init__( + self, + dim: int, + window_size: List[int], + shift_size: List[int], + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + attention_dropout: float = 0.0, + dropout: float = 0.0, + ): + super().__init__() + if len(window_size) != 2 or len(shift_size) != 2: + raise ValueError("window_size and shift_size must be of length 2") + self.window_size = window_size + self.shift_size = shift_size + self.num_heads = num_heads + self.attention_dropout = attention_dropout + self.dropout = dropout + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack( + torch.meshgrid(coords_h, coords_w, indexing="ij") + ) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) + + self.embed_dim = dim + + def __repr__(self) -> str: + return "{}(embed_dim={}, window_size={}, shift_size={}, num_heads={}, dropout={}, attn_dropout={}, dropout={})".format( + self.__class__.__name__, + self.embed_dim, + self.window_size, + self.shift_size, + self.num_heads, + self.attention_dropout, + self.dropout, + ) + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + """ + Args: + x (Tensor): Tensor with layout of [B, H, W, C] + Returns: + Tensor with same layout as input, i.e. [B, H, W, C] + """ + + N = self.window_size[0] * self.window_size[1] + relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index] + relative_position_bias = relative_position_bias.view(N, N, -1) + relative_position_bias = ( + relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) + ) + + return shifted_window_attention( + x, + self.qkv.weight, + self.proj.weight, + relative_position_bias, + self.window_size, + self.num_heads, + shift_size=self.shift_size, + attention_dropout=self.attention_dropout, + dropout=self.dropout, + qkv_bias=self.qkv.bias, + proj_bias=self.proj.bias, + ) + + +class SwinTransformerBlock(BaseModule): + """ + Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (List[int]): Window size. + shift_size (List[int]): Shift size for shifted window attention. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention + """ + + def __init__( + self, + opts, + embed_dim: int, + num_heads: int, + window_size: List[int], + shift_size: List[int], + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attn_dropout: Optional[float] = 0.0, + ffn_dropout: Optional[float] = 0.0, + stochastic_depth_prob: float = 0.0, + norm_layer: Optional[str] = "layer_norm", + ): + super().__init__() + + attn_unit = ShiftedWindowAttention( + embed_dim, + window_size, + shift_size, + num_heads, + attention_dropout=attn_dropout, + dropout=dropout, + ) + self.attn = nn.Sequential( + get_normalization_layer( + opts=opts, norm_type=norm_layer, num_features=embed_dim + ), + attn_unit, + Dropout(p=dropout), + ) + + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + ffn_latent_dim = int(embed_dim * mlp_ratio) + act_name = self.build_act_layer(opts=opts) + self.mlp = nn.Sequential( + get_normalization_layer( + opts=opts, norm_type=norm_layer, num_features=embed_dim + ), + LinearLayer(in_features=embed_dim, out_features=ffn_latent_dim, bias=True), + act_name, + Dropout(p=ffn_dropout), + LinearLayer(in_features=ffn_latent_dim, out_features=embed_dim, bias=True), + Dropout(p=dropout), + ) + self.embed_dim = embed_dim + self.ffn_dim = ffn_latent_dim + self.ffn_dropout = ffn_dropout + self.std_dropout = dropout + self.attn_fn_name = attn_unit.__class__.__name__ + self.act_fn_name = act_name.__class__.__name__ + self.norm_type = norm_layer + + @staticmethod + def build_act_layer(opts) -> nn.Module: + act_type = getattr(opts, "model.activation.name", "gelu") + neg_slope = getattr(opts, "model.activation.neg_slope", 0.1) + inplace = getattr(opts, "model.activation.inplace", False) + act_layer = get_activation_fn( + act_type=act_type, + inplace=inplace, + negative_slope=neg_slope, + num_parameters=1, + ) + return act_layer + + def __repr__(self) -> str: + return "{}(embed_dim={}, ffn_dim={}, dropout={}, ffn_dropout={}, attn_fn={}, act_fn={}, norm_fn={})".format( + self.__class__.__name__, + self.embed_dim, + self.ffn_dim, + self.std_dropout, + self.ffn_dropout, + self.attn_fn_name, + self.act_fn_name, + self.norm_type, + ) + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + x = x + self.stochastic_depth(self.attn(x)) + x = x + self.stochastic_depth(self.mlp(x)) + return x diff --git a/Adaptive Frequency Filters/affnet/modules/transformer.py b/Adaptive Frequency Filters/affnet/modules/transformer.py new file mode 100644 index 0000000..625bc3c --- /dev/null +++ b/Adaptive Frequency Filters/affnet/modules/transformer.py @@ -0,0 +1,299 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn, Tensor +from typing import Optional, Union, Tuple + +from ..layers import ( + get_normalization_layer, + LinearLayer, + get_activation_fn, + ConvLayer, + MultiHeadAttention, + Dropout, + SingleHeadAttention, + LinearSelfAttention, +) + +from ..modules import BaseModule +from ..misc.profiler import module_profile + + +class TransformerEncoder(BaseModule): + """ + This class defines the pre-norm `Transformer encoder `_ + Args: + opts: command line arguments + embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})` + ffn_latent_dim (int): Inner dimension of the FFN + num_heads (Optional[int]) : Number of heads in multi-head attention. Default: 8 + attn_dropout (Optional[float]): Dropout rate for attention in multi-head attention. Default: 0.0 + dropout (Optional[float]): Dropout rate. Default: 0.0 + ffn_dropout (Optional[float]): Dropout between FFN layers. Default: 0.0 + transformer_norm_layer (Optional[str]): Normalization layer. Default: layer_norm + + Shape: + - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches, + and :math:`C_{in}` is input embedding dim + - Output: same shape as the input + """ + + def __init__( + self, + opts, + embed_dim: int, + ffn_latent_dim: int, + num_heads: Optional[int] = 8, + attn_dropout: Optional[float] = 0.0, + dropout: Optional[float] = 0.0, + ffn_dropout: Optional[float] = 0.0, + transformer_norm_layer: Optional[str] = "layer_norm", + *args, + **kwargs + ) -> None: + + super().__init__() + + attn_unit = SingleHeadAttention( + embed_dim=embed_dim, attn_dropout=attn_dropout, bias=True + ) + if num_heads > 1: + attn_unit = MultiHeadAttention( + embed_dim, + num_heads, + attn_dropout=attn_dropout, + bias=True, + coreml_compatible=getattr( + opts, "common.enable_coreml_compatible_module", False + ), + ) + + self.pre_norm_mha = nn.Sequential( + get_normalization_layer( + opts=opts, norm_type=transformer_norm_layer, num_features=embed_dim + ), + attn_unit, + Dropout(p=dropout), + ) + + act_name = self.build_act_layer(opts=opts) + self.pre_norm_ffn = nn.Sequential( + get_normalization_layer( + opts=opts, norm_type=transformer_norm_layer, num_features=embed_dim + ), + LinearLayer(in_features=embed_dim, out_features=ffn_latent_dim, bias=True), + act_name, + Dropout(p=ffn_dropout), + LinearLayer(in_features=ffn_latent_dim, out_features=embed_dim, bias=True), + Dropout(p=dropout), + ) + self.embed_dim = embed_dim + self.ffn_dim = ffn_latent_dim + self.ffn_dropout = ffn_dropout + self.std_dropout = dropout + self.attn_fn_name = attn_unit.__class__.__name__ + self.act_fn_name = act_name.__class__.__name__ + self.norm_type = transformer_norm_layer + + @staticmethod + def build_act_layer(opts) -> nn.Module: + act_type = getattr(opts, "model.activation.name", "relu") + neg_slope = getattr(opts, "model.activation.neg_slope", 0.1) + inplace = getattr(opts, "model.activation.inplace", False) + act_layer = get_activation_fn( + act_type=act_type, + inplace=inplace, + negative_slope=neg_slope, + num_parameters=1, + ) + return act_layer + + def __repr__(self) -> str: + return "{}(embed_dim={}, ffn_dim={}, dropout={}, ffn_dropout={}, attn_fn={}, act_fn={}, norm_fn={})".format( + self.__class__.__name__, + self.embed_dim, + self.ffn_dim, + self.std_dropout, + self.ffn_dropout, + self.attn_fn_name, + self.act_fn_name, + self.norm_type, + ) + + def forward( + self, + x: Tensor, + x_prev: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + *args, + **kwargs + ) -> Tensor: + + # Multi-head attention + res = x + x = self.pre_norm_mha[0](x) # norm + x = self.pre_norm_mha[1]( + x_q=x, + x_kv=x_prev, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + *args, + **kwargs + ) # mha + x = self.pre_norm_mha[2](x) # dropout + x = x + res + + # Feed forward network + x = x + self.pre_norm_ffn(x) + return x + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + b_sz, seq_len = input.shape[:2] + + out, p_mha, m_mha = module_profile(module=self.pre_norm_mha, x=input) + + out, p_ffn, m_ffn = module_profile(module=self.pre_norm_ffn, x=input) + m_ffn = m_ffn * b_sz * seq_len + + macs = m_mha + m_ffn + params = p_mha + p_ffn + + return input, params, macs + + +class LinearAttnFFN(BaseModule): + """ + This class defines the pre-norm transformer encoder with linear self-attention in `MobileViTv2 `_ paper + Args: + opts: command line arguments + embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, P, N)` + ffn_latent_dim (int): Inner dimension of the FFN + attn_dropout (Optional[float]): Dropout rate for attention in multi-head attention. Default: 0.0 + dropout (Optional[float]): Dropout rate. Default: 0.0 + ffn_dropout (Optional[float]): Dropout between FFN layers. Default: 0.0 + norm_layer (Optional[str]): Normalization layer. Default: layer_norm_2d + + Shape: + - Input: :math:`(B, C_{in}, P, N)` where :math:`B` is batch size, :math:`C_{in}` is input embedding dim, + :math:`P` is number of pixels in a patch, and :math:`N` is number of patches, + - Output: same shape as the input + """ + + def __init__( + self, + opts, + embed_dim: int, + ffn_latent_dim: int, + attn_dropout: Optional[float] = 0.0, + dropout: Optional[float] = 0.1, + ffn_dropout: Optional[float] = 0.0, + norm_layer: Optional[str] = "layer_norm_2d", + *args, + **kwargs + ) -> None: + super().__init__() + attn_unit = LinearSelfAttention( + opts, embed_dim=embed_dim, attn_dropout=attn_dropout, bias=True + ) + + self.pre_norm_attn = nn.Sequential( + get_normalization_layer( + opts=opts, norm_type=norm_layer, num_features=embed_dim + ), + attn_unit, + Dropout(p=dropout), + ) + + self.pre_norm_ffn = nn.Sequential( + get_normalization_layer( + opts=opts, norm_type=norm_layer, num_features=embed_dim + ), + ConvLayer( + opts=opts, + in_channels=embed_dim, + out_channels=ffn_latent_dim, + kernel_size=1, + stride=1, + bias=True, + use_norm=False, + use_act=True, + ), + Dropout(p=ffn_dropout), + ConvLayer( + opts=opts, + in_channels=ffn_latent_dim, + out_channels=embed_dim, + kernel_size=1, + stride=1, + bias=True, + use_norm=False, + use_act=False, + ), + Dropout(p=dropout), + ) + + self.embed_dim = embed_dim + self.ffn_dim = ffn_latent_dim + self.ffn_dropout = ffn_dropout + self.std_dropout = dropout + self.attn_fn_name = attn_unit.__repr__() + self.norm_name = norm_layer + + @staticmethod + def build_act_layer(opts) -> nn.Module: + act_type = getattr(opts, "model.activation.name", "relu") + neg_slope = getattr(opts, "model.activation.neg_slope", 0.1) + inplace = getattr(opts, "model.activation.inplace", False) + act_layer = get_activation_fn( + act_type=act_type, + inplace=inplace, + negative_slope=neg_slope, + num_parameters=1, + ) + return act_layer + + def __repr__(self) -> str: + return "{}(embed_dim={}, ffn_dim={}, dropout={}, ffn_dropout={}, attn_fn={}, norm_layer={})".format( + self.__class__.__name__, + self.embed_dim, + self.ffn_dim, + self.std_dropout, + self.ffn_dropout, + self.attn_fn_name, + self.norm_name, + ) + + def forward( + self, x: Tensor, x_prev: Optional[Tensor] = None, *args, **kwargs + ) -> Tensor: + if x_prev is None: + # self-attention + x = x + self.pre_norm_attn(x) + else: + # cross-attention + res = x + x = self.pre_norm_attn[0](x) # norm + x = self.pre_norm_attn[1](x, x_prev) # attn + x = self.pre_norm_attn[2](x) # drop + x = x + res # residual + + # Feed forward network + x = x + self.pre_norm_ffn(x) + return x + + def profile_module( + self, input: Tensor, *args, **kwargs + ) -> Tuple[Tensor, float, float]: + out, p_mha, m_mha = module_profile(module=self.pre_norm_attn, x=input) + out, p_ffn, m_ffn = module_profile(module=self.pre_norm_ffn, x=input) + + macs = m_mha + m_ffn + params = p_mha + p_ffn + + return input, params, macs diff --git a/Adaptive Frequency Filters/affnet/neural_augmentor/__init__.py b/Adaptive Frequency Filters/affnet/neural_augmentor/__init__.py new file mode 100644 index 0000000..4e52364 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/neural_augmentor/__init__.py @@ -0,0 +1,15 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse + +from .neural_aug import build_neural_augmentor, BaseNeuralAugmentor + + +def arguments_neural_augmentor( + parser: argparse.ArgumentParser, +) -> argparse.ArgumentParser: + return BaseNeuralAugmentor.add_arguments(parser=parser) diff --git a/Adaptive Frequency Filters/affnet/neural_augmentor/neural_aug.py b/Adaptive Frequency Filters/affnet/neural_augmentor/neural_aug.py new file mode 100644 index 0000000..425014b --- /dev/null +++ b/Adaptive Frequency Filters/affnet/neural_augmentor/neural_aug.py @@ -0,0 +1,320 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import random +import torch +from torch import nn, Tensor +from typing import Optional, List +import argparse + +from utils import logger + +from affnet.neural_augmentor.utils.neural_aug_utils import ( + UniformSampler, + random_noise, + random_contrast, + random_brightness, + Clip, + FixedSampler, +) +from affnet import parameter_list + + +_distribution_tuple = (UniformSampler,) + + +class BaseNeuralAugmentor(nn.Module): + """ + Base class for `neural (or range) augmentation `_ + """ + + def __init__(self, opts, *args, **kwargs): + super().__init__() + self.opts = opts + + self.lr_multiplier = getattr( + opts, "model.learn_augmentation.lr_multiplier", 1.0 + ) + + # Set variables corresponding to different transforms to None. + # We will override them in child classes with learnable versions + self.brightness = None + self.contrast = None + self.noise = None + + self.aug_fns = [] + + def _is_valid_aug_fn_list(self, aug_fns): + if self.training: + if len(aug_fns) == 0: + logger.error( + "{} needs at least one learnable function.".format( + self.__class__.__name__ + ) + ) + + def get_trainable_parameters( + self, + weight_decay: Optional[float] = 0.0, + no_decay_bn_filter_bias: Optional[bool] = False, + *args, + **kwargs + ): + """Get trainable parameters""" + param_list = parameter_list( + named_parameters=self.named_parameters, + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + ) + return param_list, [self.lr_multiplier] * len(param_list) + + def __repr__(self): + aug_str = "{}(".format(self.__class__.__name__) + + if self.brightness is not None: + aug_str += "\n\tBrightness={}, ".format( + self.brightness.data.shape + if isinstance(self.brightness, nn.Parameter) + else self.brightness + ) + + if self.contrast is not None: + aug_str += "\n\tContrast={}, ".format( + self.contrast.data.shape + if isinstance(self.contrast, nn.Parameter) + else self.contrast + ) + + if self.noise is not None: + aug_str += "\n\tNoise={}, ".format( + self.noise.data.shape + if isinstance(self.noise, nn.Parameter) + else self.noise + ) + + aug_str += self.extra_repr() + aug_str += ")" + return aug_str + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + """Add model-specific arguments""" + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--model.learn-augmentation.mode", + type=str, + default=None, + choices=["basic", "distribution"], + help="Neural augmentation mode", + ) + + group.add_argument( + "--model.learn-augmentation.brightness", + action="store_true", + help="Learn parameters for brightness", + ) + + group.add_argument( + "--model.learn-augmentation.contrast", + action="store_true", + help="Learn parameters for contrast", + ) + + group.add_argument( + "--model.learn-augmentation.noise", + action="store_true", + help="Learn parameters for noise", + ) + + # LR multiplier + group.add_argument( + "--model.learn-augmentation.lr-multiplier", + type=float, + default=1.0, + help="LR multiplier for neural aug parameters", + ) + + return parser + + def _build_aug_fns(self, opts) -> List: + raise NotImplementedError + + def _apply_brightness(self, x: Tensor, *args, **kwargs) -> Tensor: + """ + Apply brightness augmentation function with learnable parameters. + """ + # self._check_brightness_bounds() + x_shape = [*x.shape] + x_shape[1:] = [1] * (len(x_shape) - 1) + if isinstance(self.brightness, nn.Parameter): + # learning a fixed number of parameters + magnitude = self.brightness + elif isinstance(self.brightness, _distribution_tuple): + # learning a distribution range from which parameter is sampled. + magnitude = self.brightness(x_shape, device=x.device, data_type=x.dtype) + else: + raise NotImplementedError + return random_brightness(x, magnitude, *args, **kwargs) + + def _apply_contrast(self, x: Tensor, *args, **kwargs) -> Tensor: + """ + Apply contrast augmentation function with learnable parameters. + """ + # self._check_contrast_bounds() + x_shape = [*x.shape] + x_shape[1:] = [1] * (len(x_shape) - 1) + + if isinstance(self.contrast, nn.Parameter): + # learning a fixed number of parameters + magnitude = self.contrast + elif isinstance(self.contrast, _distribution_tuple): + # learning a distribution range from which parameter is sampled. + magnitude = self.contrast(x_shape, device=x.device, data_type=x.dtype) + else: + raise NotImplementedError + return random_contrast(x, magnitude, *args, *kwargs) + + def _apply_noise(self, x: Tensor, *args, **kwargs) -> Tensor: + # self._check_noise_bounds() + x_shape = [*x.shape] + x_shape[1:] = [1] * (len(x_shape) - 1) + + if isinstance(self.noise, nn.Parameter): + # learning a fixed number of parameters + variance = self.noise + elif isinstance(self.noise, _distribution_tuple): + # learning a distribution range from which parameter is sampled. + variance = self.noise(x_shape, device=x.device, data_type=x.dtype) + else: + raise NotImplementedError + return random_noise(x, variance, *args, *kwargs) + + def forward(self, x: Tensor, *args, **kwargs) -> Tensor: + batch_size, in_channels, in_height, in_width = x.shape + + # Randomly apply augmentation to 50% of the samples + n_aug_samples = max(1, (batch_size // 2)) + + # shuffle the order of augmentations + random.shuffle(self.aug_fns) + + for aug_fn in self.aug_fns: + # select 50% samples for augmentation + sample_ids = torch.randperm( + n=batch_size, dtype=torch.long, device=x.device + )[:n_aug_samples] + x_aug = torch.index_select(x, dim=0, index=sample_ids) + # apply augmentation + x_aug = aug_fn(x=x_aug) + # copy augmented samples to tensor + x = torch.index_copy(x, dim=0, source=x_aug, index=sample_ids) + + # clip the values so that they are between 0 and 1 + x = torch.clip(x, min=0.0, max=1.0) + return x + + +class BasicNeuralAugmentor(BaseNeuralAugmentor): + """ + Basic neural augmentation. This class learns per-channel augmentation parameters + and apply the same parameter to all images in a batch. + + See `neural (or range) augmentation `_ paper for details. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts=opts, *args, **kwargs) + aug_fns = self._build_aug_fns(opts=opts) + + self._is_valid_aug_fn_list(aug_fns) + + self.aug_fns = aug_fns + + def _build_aug_fns(self, opts) -> List: + aug_fns = [] + if getattr(opts, "model.learn_augmentation.brightness", False): + self.brightness = FixedSampler( + value=1.0, clip_fn=Clip(min_val=0.1, max_val=10.0) + ) + aug_fns.append(self._apply_brightness) + + if getattr(opts, "model.learn_augmentation.contrast", False): + self.contrast = FixedSampler( + value=1.0, clip_fn=Clip(min_val=0.1, max_val=10.0) + ) + aug_fns.append(self._apply_contrast) + + if getattr(opts, "model.learn_augmentation.noise", False): + self.noise = FixedSampler(value=0.0, clip_fn=Clip(min_val=0.0, max_val=1.0)) + aug_fns.append(self._apply_noise) + + return aug_fns + + +class DistributionNeuralAugmentor(BaseNeuralAugmentor): + """ + Distribution-based neural (or range) augmentation. This class samples the augmentation parameters + from a specified distribution with learnable range. + + See `neural (or range) augmentation `_ paper for details. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts=opts, *args, **kwargs) + + aug_fns = self._build_aug_fns_with_uniform_dist(opts=opts) + self._is_valid_aug_fn_list(aug_fns) + self.aug_fns = aug_fns + + def _build_aug_fns_with_uniform_dist(self, opts) -> List: + # need to define the learnable parameters in a way that are compatible with bucketing + aug_fns = [] + if getattr(opts, "model.learn_augmentation.brightness", False): + self.brightness = UniformSampler( + low=0.5, + high=1.5, + min_fn=Clip(min_val=0.1, max_val=0.9), + max_fn=Clip(min_val=1.1, max_val=10.0), + ) + aug_fns.append(self._apply_brightness) + + if getattr(opts, "model.learn_augmentation.contrast", False): + self.contrast = UniformSampler( + low=0.5, + high=1.5, + min_fn=Clip(min_val=0.1, max_val=0.9), + max_fn=Clip(min_val=1.1, max_val=10.0), + ) + aug_fns.append(self._apply_contrast) + + if getattr(opts, "model.learn_augmentation.noise", False): + self.noise = UniformSampler( + low=0.0, + high=0.1, + min_fn=Clip(min_val=0.0, max_val=0.00005), + max_fn=Clip(min_val=0.0001, max_val=1.0), + ) + aug_fns.append(self._apply_noise) + + return aug_fns + + +def build_neural_augmentor(opts, *args, **kwargs): + mode = getattr(opts, "model.learn_augmentation.mode", None) + + if mode is None: + mode = "none" + + mode = mode.lower() + if mode == "distribution": + return DistributionNeuralAugmentor(opts=opts, *args, **kwargs) + elif mode == "basic": + return BasicNeuralAugmentor(opts=opts, *args, **kwargs) + else: + return None diff --git a/Adaptive Frequency Filters/affnet/neural_augmentor/utils/__init__.py b/Adaptive Frequency Filters/affnet/neural_augmentor/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/affnet/neural_augmentor/utils/neural_aug_utils.py b/Adaptive Frequency Filters/affnet/neural_augmentor/utils/neural_aug_utils.py new file mode 100644 index 0000000..1dfe42a --- /dev/null +++ b/Adaptive Frequency Filters/affnet/neural_augmentor/utils/neural_aug_utils.py @@ -0,0 +1,141 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import Tensor, nn +from typing import Optional, Any + + +class Clip(nn.Module): + def __init__( + self, + min_val: float, + max_val: float, + hard_clip: Optional[bool] = False, + *args, + **kwargs, + ) -> None: + super().__init__() + self.min_val = min_val + self.max_val = max_val + self.hard_clip = hard_clip + + def forward(self, x: Any) -> Any: + if self.hard_clip: + with torch.no_grad(): + return x.clamp_(min=self.min_val, max=self.max_val) + else: + return (torch.sigmoid(x) * (self.max_val - self.min_val)) + self.min_val + + def __repr__(self): + return "{}(min={}, max={}, clipping={})".format( + self.__class__.__name__, + self.min_val, + self.max_val, + "hard" if self.hard_clip else "soft", + ) + + +class Identity(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, x: Any) -> Any: + return x + + +class FixedSampler(nn.Module): + def __init__( + self, + value: float, + clip_fn: Optional[nn.Module] = Identity(), + *args, + **kwargs, + ): + super().__init__() + self._value = nn.Parameter(torch.FloatTensor(1, 3, 1, 1).fill_(value)) + self.clip_fn = clip_fn + + def forward( + self, sample_shape=(), data_type=torch.float, device=torch.device("cpu") + ) -> Tensor: + # sample values from uniform distribution + return self.clip_fn(self._value) + + def __repr__(self): + return "{}(clip_fn={})".format( + self.__class__.__name__, + self.clip_fn, + ) + + +class UniformSampler(nn.Module): + def __init__( + self, + low: float, + high: float, + min_fn: Optional[nn.Module] = Identity(), + max_fn: Optional[nn.Module] = Identity(), + *args, + **kwargs, + ): + super().__init__() + self._low = nn.Parameter(torch.tensor(low, dtype=torch.float)) + self._high = nn.Parameter(torch.tensor(high, dtype=torch.float)) + self.min_fn = min_fn + self.max_fn = max_fn + + def forward( + self, sample_shape=(), data_type=torch.float, device=torch.device("cpu") + ) -> Tensor: + # sample values from uniform distribution + rand_tensor = torch.rand(sample_shape, dtype=data_type, device=device) + return self.low + rand_tensor * (self.high - self.low) + + @property + def high(self): + return self.max_fn(self._high) + + @property + def low(self): + return self.min_fn(self._low) + + def __repr__(self): + return "{}(min_fn={}, max_fn={})".format( + self.__class__.__name__, + self.min_fn, + self.max_fn, + ) + + +def random_noise(x: Tensor, variance: Tensor, *args, **kwargs) -> Tensor: + """Apply random noise sampled.""" + noise = torch.randn_like(x) * variance + x = x + noise + return x + + +def random_contrast(x: Tensor, magnitude: Tensor, *args, **kwargs) -> Tensor: + # compute per-channel mean + per_channel_mean = torch.mean(x, dim=[-1, -2], keepdim=True) + + # contrast can be written as + # (1 - contrast_factor) * per_channel_mean + img * contrast_factor + x = ((1.0 - magnitude) * per_channel_mean) + (x * magnitude) + return x + + +def random_brightness(x: Tensor, magnitude: Tensor, *args, **kwargs) -> Tensor: + """ + Brightness function. + """ + x = x * magnitude + return x + + +def identity(x: Tensor, *args, **kwargs) -> Tensor: + """Identity function""" + return x diff --git a/Adaptive Frequency Filters/affnet/text_encoders/__init__.py b/Adaptive Frequency Filters/affnet/text_encoders/__init__.py new file mode 100644 index 0000000..1bf1563 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/text_encoders/__init__.py @@ -0,0 +1,92 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import importlib +import argparse +from typing import Optional + +from utils import logger + +from .base_text_encoder import BaseTextEncoder + + +TEXT_ENCODER_REGISTRY = {} + + +def register_text_encoder(name): + # register the text_encoder class + def register_text_encoder_class(cls): + if name in TEXT_ENCODER_REGISTRY: + raise ValueError( + "Cannot register duplicate text_encoder class ({})".format(name) + ) + + if not issubclass(cls, BaseTextEncoder): + raise ValueError( + "Text encoder class ({}: {}) must extend BaseTextEncoder".format( + name, cls.__name__ + ) + ) + + TEXT_ENCODER_REGISTRY[name] = cls + return cls + + return register_text_encoder_class + + +def arguments_text_encoder(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + # add arguments for text_encoder + parser = BaseTextEncoder.add_arguments(parser) + + # add augmentation specific arguments + for k, v in TEXT_ENCODER_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + return parser + + +def supported_text_encoder_str(text_encoder_name: Optional[str] = None) -> None: + """Helper utility to print supported text_encoder names in case specified text_encoder + name is not part of the implemented text encoders. + """ + supp_list = list(TEXT_ENCODER_REGISTRY.keys()) + if text_encoder_name is None: + supp_str = "Text encoder name can't be None. \n Supported text encoders are:" + else: + supp_str = "Text encoder ({}) is not yet supported. \n Supported text encoders are:".format( + text_encoder_name + ) + for t_name in supp_list: + supp_str += "\n\t{}".format(t_name) + logger.error(supp_str + "\n") + + +def build_text_encoder(opts, projection_dim: int, *args, **kwargs) -> BaseTextEncoder: + """Helper function to build the text encoder""" + text_encoder_name = getattr(opts, "model.text.name", None) + if text_encoder_name is None: + supported_text_encoder_str(text_encoder_name) + + if text_encoder_name in list(TEXT_ENCODER_REGISTRY.keys()): + return TEXT_ENCODER_REGISTRY[text_encoder_name]( + opts, projection_dim, *args, **kwargs + ) + else: + supported_text_encoder_str(text_encoder_name) + + +# automatically import the text encoders +text_encoder_dir = os.path.dirname(__file__) + +for file in os.listdir(text_encoder_dir): + path = os.path.join(text_encoder_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + text_encoder_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("affnet.text_encoders." + text_encoder_name) diff --git a/Adaptive Frequency Filters/affnet/text_encoders/base_text_encoder.py b/Adaptive Frequency Filters/affnet/text_encoders/base_text_encoder.py new file mode 100644 index 0000000..c88e4b3 --- /dev/null +++ b/Adaptive Frequency Filters/affnet/text_encoders/base_text_encoder.py @@ -0,0 +1,111 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +import argparse +from typing import Optional, Tuple, Dict, Any + +from utils import logger +from utils.ddp_utils import is_master + +from affnet import parameter_list +from affnet.layers import norm_layers_tuple +from affnet.misc.init_utils import initialize_weights + + +class BaseTextEncoder(nn.Module): + """Base class for text encoder""" + + def __init__(self, opts, projection_dim: int, *args, **kwargs) -> None: + is_master_node = is_master(opts) + vocab_size = getattr(opts, "dataset.text_vocab_size", None) + if getattr(opts, "common.debug_mode", False): + vocab_size = 100 + if vocab_size is None and is_master_node: + logger.error( + "Vocabulary size can't be None or -1 in {}. Got: {}".format( + self.__class__.__name__, vocab_size + ) + ) + + super(BaseTextEncoder, self).__init__() + self.opts = opts + self.projection_dim = projection_dim + self.is_master_node = is_master_node + self.vocab_size = vocab_size + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add model specific arguments""" + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--model.text.name", + type=str, + default=None, + help="Name of the text encoder", + ) + + return parser + + def reset_parameters(self): + """Initialize model weights""" + initialize_weights(opts=self.opts, modules=self.modules()) + + def get_trainable_parameters( + self, + weight_decay: Optional[float] = 0.0, + no_decay_bn_filter_bias: Optional[bool] = False, + *args, + **kwargs + ): + + param_list = parameter_list( + named_parameters=self.named_parameters, + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + *args, + **kwargs + ) + return param_list, [1.0] * len(param_list) + + def profile_model(self, input: Tensor) -> Optional[Tuple[Tensor, float, float]]: + """ + Child classes must implement this function to compute FLOPs and parameters + """ + raise NotImplementedError + + def freeze_norm_layers(self) -> None: + for m in self.modules(): + if isinstance(m, norm_layers_tuple): + m.eval() + m.weight.requires_grad = False + m.bias.requires_grad = False + m.training = False + + def forward( + self, + text_tokens: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + *args, + **kwargs + ) -> Any: + raise NotImplementedError + + def dummy_input_and_label(self, batch_size: int) -> Dict: + """Create dummy input and labels for CI/CD purposes. Child classes must override it + if functionality is different. + """ + seq_length = 77 + vocab_size = 10 + text_tensor = torch.randint( + low=0, high=vocab_size, size=(batch_size, seq_length) + ).long() + return {"text": text_tensor} diff --git a/Adaptive Frequency Filters/affnet/text_encoders/transformer.py b/Adaptive Frequency Filters/affnet/text_encoders/transformer.py new file mode 100644 index 0000000..a49a0ef --- /dev/null +++ b/Adaptive Frequency Filters/affnet/text_encoders/transformer.py @@ -0,0 +1,515 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse +import math +import torch +from torch import Tensor, nn +from typing import Optional, Sequence, Any +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint as gradient_checkpoint_fn + +from utils import logger + +from affnet.layers import ( + Embedding, + PositionalEmbedding, + Dropout, + get_normalization_layer, +) +from affnet.modules import TransformerEncoder + +from . import BaseTextEncoder, register_text_encoder + + +@register_text_encoder(name="transformer") +class TextTransformer(BaseTextEncoder): + def __init__(self, opts, projection_dim: int, *args, **kwargs) -> None: + model_dim = getattr(opts, "model.text.transformer.model_dim", 512) + no_scale_embedding = getattr( + opts, "model.text.transformer.no_scale_embedding", False + ) + no_pos_embedding = getattr( + opts, "model.text.transformer.no_pos_embedding", False + ) + embed_dropout = getattr(opts, "model.text.transformer.embed_dropout", 0.0) + dropout = getattr(opts, "model.text.transformer.dropout", 0.0) + attn_dropout = getattr(opts, "model.text.transformer.attn_dropout", 0.0) + ffn_dropout = getattr(opts, "model.text.transformer.ffn_dropout", 0.0) + norm_layer = getattr(opts, "model.text.transformer.norm_layer", None) + + gradient_ckpt = getattr( + opts, "model.text.transformer.gradient_checkpoint", False + ) + + if norm_layer is None: + logger.error( + "Normalization layer can not be None in {}".format( + self.__class__.__name__ + ) + ) + + super().__init__(opts=opts, projection_dim=projection_dim, *args, **kwargs) + + # token embedding layer + padding_index = getattr(opts, "dataset.padding_index", None) + self.embedding_layer = Embedding( + opts=opts, + embedding_dim=model_dim, + padding_idx=padding_index, + num_embeddings=self.vocab_size, + ) + self.embed_scale = 1.0 if no_scale_embedding else model_dim**-0.5 + + context_length = getattr(opts, "dataset.text_context_length", None) + + if getattr(opts, "common.debug_mode", False): + context_length = 77 + + assert context_length is not None, ( + "Context length can't be None. Please set dataset.text_context_length " + "argument in your dataset class" + ) + + self.positional_embedding = ( + None + if no_pos_embedding + else PositionalEmbedding( + opts=opts, + num_embeddings=context_length, + embedding_dim=model_dim, + padding_idx=getattr(opts, "dataset.padding_index", None), + is_learnable=not getattr( + opts, "model.text.transformer.sinusoidal_pos_emb", False + ), + ) + ) + + self.embedding_dropout = Dropout(p=embed_dropout) + + # Transformer layer + + n_transformer_layers = getattr( + opts, "model.text.transformer.n_transformer_layers", 6 + ) + # FFN multipliers for transformer layer + ffn_multipliers = getattr( + opts, "model.text.transformer.ffn_multiplier_per_layer", 4.0 + ) + if isinstance(ffn_multipliers, (float, int)): + ffn_multipliers = [ffn_multipliers] * n_transformer_layers + + if not isinstance(ffn_multipliers, Sequence): + logger.error( + "{} expects FFN multipliers as a list, whose length is the same as number of " + "transformer layers. Got: {}".format( + self.__class__.__name__, type(ffn_multipliers) + ) + ) + elif ( + isinstance(ffn_multipliers, Sequence) + and len(ffn_multipliers) != n_transformer_layers + ): + logger.error( + "We need FFN multiplier for each transformer layer. Got {} ffn multipliers while number of " + "transformer layers = {}".format( + len(ffn_multipliers), n_transformer_layers + ) + ) + ffn_dims = [ + int(math.ceil(model_dim * ffn_mult / 16.0) * 16.0) + for ffn_mult in ffn_multipliers + ] + + # Heads for transformer layers + mha_heads = getattr(opts, "model.text.transformer.n_heads_per_layer", 8) + if isinstance(mha_heads, int): + mha_heads = [mha_heads] * n_transformer_layers + + if not isinstance(mha_heads, Sequence): + logger.error( + "{} expects MHA heads as a list, whose length is the same as number of " + "transformer layers. Got: {}".format( + self.__class__.__name__, type(mha_heads) + ) + ) + elif isinstance(mha_heads, Sequence) and len(mha_heads) != n_transformer_layers: + logger.error( + "{} needs MHA heads for each transformer layer. Got {} mha heads while number of " + "transformer layers = {}".format( + self.__class__.__name__, len(mha_heads), n_transformer_layers + ) + ) + + self.transformer = nn.ModuleList( + [ + TransformerEncoder( + opts=opts, + embed_dim=model_dim, + num_heads=mha_heads[layer_idx], + ffn_latent_dim=ffn_dims[layer_idx], + attn_dropout=attn_dropout, + ffn_dropout=ffn_dropout, + dropout=dropout, + transformer_norm_layer=norm_layer, + ) + for layer_idx in range(n_transformer_layers) + ] + ) + self.final_layer_norm = get_normalization_layer( + opts, num_features=model_dim, norm_type=norm_layer + ) + + self.projection_layer = nn.Parameter( + torch.empty(model_dim, self.projection_dim) + ) + self.model_dim = model_dim + self.reset_parameters_clip_style() + self.gradient_ckpt = gradient_ckpt + self.use_pytorch_mha = False + self.causal_masking = getattr( + opts, "model.text.transformer.causal_masking", False + ) + self.classes_per_split_zero_shot = max( + 1, + int(getattr(opts, "model.text.transformer.classes_per_split_zero_shot", 1)), + ) + + def reset_parameters_clip_style(self): + """This function resets the weights of Transformer model as done in the CLIP paper""" + + # reset the weights of the embedding and positional embedding layers + nn.init.normal_(self.embedding_layer.weight, mean=0.0, std=0.02) + # if self.positional_embedding is not None and not getattr( + # self.opts, "model.text.transformer.sinusoidal_pos_emb", False + # ): + # nn.init.normal_( + # self.positional_embedding.pos_embed.weight, mean=0.0, std=0.01 + # ) + + # compute standard deviation for different linear layers in transformer model + attn_std = self.model_dim**-0.5 + proj_std = attn_std * ((2 * len(self.transformer)) ** -0.5) + fc_std = (2 * self.model_dim) ** -0.5 + + for block in self.transformer: + # multi-head attention QKV projection layer + nn.init.normal_( + block.pre_norm_mha[1].qkv_proj.weight, mean=0.0, std=attn_std + ) + # multi-head attention output projection layer + nn.init.normal_( + block.pre_norm_mha[1].out_proj.weight, mean=0.0, std=proj_std + ) + # FFN expansion layer + nn.init.normal_(block.pre_norm_ffn[1].weight, mean=0.0, std=fc_std) + # FFN reduction layer + nn.init.normal_(block.pre_norm_ffn[4].weight, mean=0.0, std=proj_std) + + nn.init.normal_(self.projection_layer, mean=0.0, std=attn_std) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--model.text.transformer.model-dim", + type=int, + default=512, + help="Model dimension of the transformer model", + ) + + group.add_argument( + "--model.text.transformer.no-scale-embedding", + action="store_true", + help="Do not scale the output of embedding layer in {}".format( + cls.__name__ + ), + ) + + group.add_argument( + "--model.text.transformer.no-pos-embedding", + action="store_true", + help="Do not add positional embeddings to the output of embedding layer in {}".format( + cls.__name__ + ), + ) + + group.add_argument( + "--model.text.transformer.embed-dropout", + type=float, + default=0.0, + help="Dropout in embedding layer", + ) + + # transformer layer parameters + default_layers = 6 + group.add_argument( + "--model.text.transformer.n-transformer-layers", + type=int, + default=default_layers, + help="Number of transformer layers in {}".format(cls.__name__), + ) + group.add_argument( + "--model.text.transformer.n-heads-per-layer", + type=int, + default=[8] * default_layers, + nargs="+", + help="Number of transformer heads per transformer layer", + ) + + group.add_argument( + "--model.text.transformer.ffn-multiplier-per-layer", + type=float, + default=[4.0] * default_layers, + nargs="+", + help="FFN multiplier for each transformer layer", + ) + group.add_argument( + "--model.text.transformer.attn-dropout", + type=float, + default=0.0, + help="Dropout in multi-head attention", + ) + group.add_argument( + "--model.text.transformer.ffn-dropout", + type=float, + default=0.0, + help="Dropout between linear layers in FFN", + ) + group.add_argument( + "--model.text.transformer.dropout", + type=float, + default=0.0, + help="Dropout in transformer", + ) + + group.add_argument( + "--model.text.transformer.norm-layer", + type=str, + default="layer_norm", + help="Normalization layer", + ) + + group.add_argument( + "--model.text.transformer.sinusoidal-pos-emb", + action="store_true", + help="Use sinusoidal positional embedding", + ) + + group.add_argument( + "--model.text.transformer.gradient-checkpoint", + action="store_true", + help="Use gradient checkpointing", + ) + group.add_argument( + "--model.text.transformer.num-checkpoint-segments", + type=int, + default=1, + help="Number of gradient checkpoint segments", + ) + + group.add_argument( + "--model.text.transformer.causal-masking", + action="store_true", + help="Use causal masking", + ) + + group.add_argument( + "--model.text.transformer.classes-per-split-zero-shot", + type=int, + default=20, + help="Divide zero-shot classes into these many chunks, for faster processing", + ) + + return parser + + def forward_embedding( + self, + text_tokens: Tensor, + ): + # [Batch, Seq_len] --> [Batch, Seq_len, hidden_dim] + token_emb = self.embedding_layer(text_tokens) + # token_emb = self.embed_scale * token_emb + seq_len = token_emb.shape[1] + if self.positional_embedding is not None: + token_emb = token_emb + self.positional_embedding(seq_len).to( + token_emb.dtype + ) + token_emb = self.embedding_dropout(token_emb) + return token_emb + + def build_attention_mask(self, context_length: int, batch_size: int): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(context_length, context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + if not self.use_pytorch_mha: + mask = mask.unsqueeze(0) # add dummy batch dimension + mask = mask.expand(batch_size, -1, -1) + return mask + + def encode_text( + self, + text_tokens: Tensor, + key_padding_mask: Optional[Tensor] = None, + *args, + **kwargs + ) -> Tensor: + # discrete tokens to continuous embeddings + # [Batch, Seq_len] --> [Batch, Seq_len, hidden_dim] + token_emb = self.forward_embedding(text_tokens) + + # [1, Seq_len, Seq_len] + attn_mask = None + if self.causal_masking: + attn_mask = self.build_attention_mask( + context_length=text_tokens.shape[1], batch_size=text_tokens.shape[0] + ) + attn_mask = attn_mask.to(device=token_emb.device, dtype=token_emb.dtype) + key_padding_mask = None + + if self.use_pytorch_mha: + # [Batch, Seq_len, hidden_dim] --> [Seq_len, Batch, hidden_dim] + # we will use PyTorch's multi-head attention, which uses sequence_first format + token_emb = token_emb.transpose(0, 1) + + for layer in self.transformer: + if self.gradient_ckpt: + token_emb = gradient_checkpoint_fn( + layer, token_emb, None, key_padding_mask, attn_mask + ) + else: + token_emb = layer( + token_emb, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + use_pytorch_mha=self.use_pytorch_mha, + ) + + # Apply layer norm + token_emb = self.final_layer_norm(token_emb) + + # take features from the eot embedding (eot_token is the highest number in each sequence) + if self.use_pytorch_mha: + token_emb = token_emb[ + text_tokens.argmax(dim=-1), torch.arange(text_tokens.shape[0]) + ] + else: + token_emb = token_emb[ + torch.arange(text_tokens.shape[0]), text_tokens.argmax(dim=-1) + ] + + token_emb = token_emb @ self.projection_layer + # normalize text features + token_emb = F.normalize(token_emb, dim=-1) + return token_emb + + def forward_zero_shot( + self, + text_tokens: Tensor, + key_padding_mask: Optional[Tensor] = None, + *args, + **kwargs + ) -> Tensor: + # In case of zero-shot evaluation, text tokens is of shape [Batch, num_classes, num_captions, context_length] + # For example, in the ImageNet dataset, we have 1000 classes, and for each class we generate certain number of + # captions (each caption with context_length tokens) + + if self.training: + raise NotImplementedError( + "Zero-shot evaluation is only supported with eval mode" + ) + + if text_tokens.ndim != 4: + logger.error( + "For zero-shot evaluation, expected size of text is [Batch, Num_classes, num_captions, context_len]" + ) + + batch_size, num_classes, num_captions, context_len = text_tokens.shape + + # for zero-shot evaluation, text templates are the same across all images in the batch + # Therefore, batch size should be 1. + if batch_size != 1: + logger.error( + "For zero-shot evaluation, text templates are the same across all images in the batch." + "Therefore, batch size should be 1. Got: {}".format(batch_size) + ) + + text_features = [] + + for start_idx in range(0, num_classes, self.classes_per_split_zero_shot): + end_idx = min(start_idx + self.classes_per_split_zero_shot, num_classes) + + text_tokens_split = text_tokens[0, start_idx:end_idx, ...] + num_classes_split = text_tokens_split.shape[0] + text_tokens_split = text_tokens_split.reshape( + num_classes_split * num_captions, context_len + ) + + key_padding_mask_split = None + if key_padding_mask is not None: + key_padding_mask_split = key_padding_mask[0, start_idx:end_idx, ...] + key_padding_mask_split = key_padding_mask_split.reshape( + num_classes_split * num_captions, context_len + ) + + # [num_classes_per_split * num_cations, context_len] --> [num_classes_per_split * num_cations, latent_dim] + class_embedding_split = self.encode_text( + text_tokens=text_tokens_split, key_padding_mask=key_padding_mask_split + ) + + # [num_classes_per_split * num_cations, latent_dim] --> [num_classes_per_split, num_cations, latent_dim] + class_embedding_split = class_embedding_split.reshape( + num_classes_split, num_captions, class_embedding_split.shape[-1] + ) + + # Compute mean of all classes + # [num_classes_per_split, num_cations, latent_dim] --> [num_classes_per_split, latent_dim] + mean_class_embedding_split = class_embedding_split.mean(dim=1) + + # Normalize the embeddings + mean_class_embedding_split = F.normalize(mean_class_embedding_split, dim=-1) + + text_features.append(mean_class_embedding_split) + + # [num_classes_per_split, latent_dim] * num_splits --> [num_classes, Latent_dim] + text_features = torch.cat(text_features, dim=0) + # [num_classes, Latent_dim] --> [Latent_dim, num_classes] + text_features = text_features.transpose(0, 1) + return text_features + + def forward( + self, + text_tokens: Tensor, + key_padding_mask: Optional[Tensor] = None, + *args, + **kwargs + ) -> Tensor: + + if text_tokens.dim() == 4: + # It's for zero-shot evaluation. + # Each class in the dataset has multiple captions + return self.forward_zero_shot( + text_tokens=text_tokens, + key_padding_mask=key_padding_mask, + *args, + **kwargs + ) + elif text_tokens.dim() == 2: + # Image-text pair data with single caption + # [B, CL] --> [B, d] + text_tokens = self.encode_text( + text_tokens=text_tokens, + key_padding_mask=key_padding_mask, + *args, + **kwargs + ) + return text_tokens + else: + raise NotImplementedError diff --git a/Adaptive Frequency Filters/common/__init__.py b/Adaptive Frequency Filters/common/__init__.py new file mode 100644 index 0000000..bd4f095 --- /dev/null +++ b/Adaptive Frequency Filters/common/__init__.py @@ -0,0 +1,25 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +MIN_TORCH_VERSION = "1.11.0" + +SUPPORTED_IMAGE_EXTNS = [".png", ".jpg", ".jpeg"] # Add image formats here +SUPPORTED_MODALITIES = ["image", "video"] +SUPPORTED_VIDEO_CLIP_VOTING_FN = ["sum", "max"] +SUPPORTED_VIDEO_READER = ["pyav", "decord"] + +DEFAULT_IMAGE_WIDTH = DEFAULT_IMAGE_HEIGHT = 256 +DEFAULT_IMAGE_CHANNELS = 3 +DEFAULT_VIDEO_FRAMES = 8 +DEFAULT_LOG_FREQ = 500 + +DEFAULT_ITERATIONS = 300000 +DEFAULT_EPOCHS = 300 +DEFAULT_MAX_ITERATIONS = DEFAULT_MAX_EPOCHS = 10000000 + +TMP_RES_FOLDER = "results_tmp" + +TMP_CACHE_LOC = "/tmp" diff --git a/Adaptive Frequency Filters/data/__init__.py b/Adaptive Frequency Filters/data/__init__.py new file mode 100644 index 0000000..06e77ab --- /dev/null +++ b/Adaptive Frequency Filters/data/__init__.py @@ -0,0 +1,7 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from .data_loaders import create_train_val_loader, create_eval_loader diff --git a/Adaptive Frequency Filters/data/collate_fns/__init__.py b/Adaptive Frequency Filters/data/collate_fns/__init__.py new file mode 100644 index 0000000..bd1e3d7 --- /dev/null +++ b/Adaptive Frequency Filters/data/collate_fns/__init__.py @@ -0,0 +1,95 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import importlib +import argparse + +COLLATE_FN_REGISTRY = {} + + +def register_collate_fn(name): + def register_collate_fn_method(f): + if name in COLLATE_FN_REGISTRY: + raise ValueError( + "Cannot register duplicate collate function ({})".format(name) + ) + COLLATE_FN_REGISTRY[name] = f + return f + + return register_collate_fn_method + + +def arguments_collate_fn(parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Collate function arguments", description="Collate function arguments" + ) + group.add_argument( + "--dataset.collate-fn-name-train", + type=str, + default="default_collate_fn", + help="Name of collate function", + ) + group.add_argument( + "--dataset.collate-fn-name-val", + type=str, + default="default_collate_fn", + help="Name of collate function", + ) + group.add_argument( + "--dataset.collate-fn-name-eval", + type=str, + default=None, + help="Name of collate function used for evaluation. " + "Default is None, i.e., use PyTorch's inbuilt collate function", + ) + return parser + + +def build_collate_fn(opts, *args, **kwargs): + collate_fn_name_train = getattr( + opts, "dataset.collate_fn_name_train", "default_collate_fn" + ) + collate_fn_name_val = getattr( + opts, "dataset.collate_fn_name_val", "default_collate_fn" + ) + collate_fn_train = None + if ( + collate_fn_name_train is not None + and collate_fn_name_train in COLLATE_FN_REGISTRY + ): + collate_fn_train = COLLATE_FN_REGISTRY[collate_fn_name_train] + + collate_fn_val = None + if collate_fn_name_val is None: + collate_fn_val = collate_fn_name_train + elif collate_fn_name_val is not None and collate_fn_name_val in COLLATE_FN_REGISTRY: + collate_fn_val = COLLATE_FN_REGISTRY[collate_fn_name_val] + + return collate_fn_train, collate_fn_val + + +def build_eval_collate_fn(opts, *args, **kwargs): + collate_fn_name_eval = getattr(opts, "dataset.collate_fn_name_eval", None) + collate_fn_eval = None + if collate_fn_name_eval is not None and collate_fn_name_eval in COLLATE_FN_REGISTRY: + collate_fn_eval = COLLATE_FN_REGISTRY[collate_fn_name_eval] + + return collate_fn_eval + + +# automatically import the augmentations +collate_fn_dir = os.path.dirname(__file__) + +for file in os.listdir(collate_fn_dir): + path = os.path.join(collate_fn_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + collate_fn_fname = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("data.collate_fns." + collate_fn_fname) diff --git a/Adaptive Frequency Filters/data/collate_fns/collate_functions.py b/Adaptive Frequency Filters/data/collate_fns/collate_functions.py new file mode 100644 index 0000000..524a6df --- /dev/null +++ b/Adaptive Frequency Filters/data/collate_fns/collate_functions.py @@ -0,0 +1,43 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- +import numpy as np +import torch +from typing import List, Dict + +from utils import logger + +from . import register_collate_fn + + +@register_collate_fn(name="default_collate_fn") +def default_collate_fn(batch: List[Dict], opts): + """Default collate function""" + batch_size = len(batch) + + keys = list(batch[0].keys()) + + new_batch = {k: [] for k in keys} + for b in range(batch_size): + for k in keys: + new_batch[k].append(batch[b][k]) + + # stack the keys + for k in keys: + batch_elements = new_batch.pop(k) + + if isinstance(batch_elements[0], (int, float, np.integer, np.floating)): + # list of ints or floats + batch_elements = torch.as_tensor(batch_elements) + else: + # stack tensors (including 0-dimensional) + try: + batch_elements = torch.stack(batch_elements, dim=0).contiguous() + except Exception as e: + logger.error("Unable to stack the tensors. Error: {}".format(e)) + + new_batch[k] = batch_elements + + return new_batch diff --git a/Adaptive Frequency Filters/data/data_loaders.py b/Adaptive Frequency Filters/data/data_loaders.py new file mode 100644 index 0000000..205ef25 --- /dev/null +++ b/Adaptive Frequency Filters/data/data_loaders.py @@ -0,0 +1,138 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from functools import partial + +from utils import logger +from utils.ddp_utils import is_master +from utils.tensor_utils import image_size_from_opts + +from .datasets import train_val_datasets, evaluation_datasets +from .sampler import build_sampler +from .collate_fns import build_collate_fn, build_eval_collate_fn +from .loader.dataloader import affnetDataLoader + + +def create_eval_loader(opts): + eval_dataset = evaluation_datasets(opts) + n_eval_samples = len(eval_dataset) + is_master_node = is_master(opts) + + # overwrite the validation argument + setattr( + opts, "dataset.val_batch_size0", getattr(opts, "dataset.eval_batch_size0", 1) + ) + + # we don't need variable batch sampler for evaluation + sampler_name = getattr(opts, "sampler.name", "batch_sampler") + crop_size_h, crop_size_w = image_size_from_opts(opts) + if sampler_name.find("video") > -1 and sampler_name != "video_batch_sampler": + clips_per_video = getattr(opts, "sampler.vbs.clips_per_video", 1) + frames_per_clip = getattr(opts, "sampler.vbs.num_frames_per_clip", 8) + setattr(opts, "sampler.name", "video_batch_sampler") + setattr(opts, "sampler.bs.crop_size_width", crop_size_w) + setattr(opts, "sampler.bs.crop_size_height", crop_size_h) + setattr(opts, "sampler.bs.clips_per_video", clips_per_video) + setattr(opts, "sampler.bs.num_frames_per_clip", frames_per_clip) + elif sampler_name.find("var") > -1: + setattr(opts, "sampler.name", "batch_sampler") + setattr(opts, "sampler.bs.crop_size_width", crop_size_w) + setattr(opts, "sampler.bs.crop_size_height", crop_size_h) + + eval_sampler = build_sampler( + opts=opts, n_data_samples=n_eval_samples, is_training=False + ) + + collate_fn_eval = build_eval_collate_fn(opts=opts) + + data_workers = getattr(opts, "dataset.workers", 1) + persistent_workers = False + pin_memory = False + + eval_loader = affnetDataLoader( + dataset=eval_dataset, + batch_size=1, + batch_sampler=eval_sampler, + num_workers=data_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + collate_fn=partial(collate_fn_eval, opts=opts) + if collate_fn_eval is not None + else None, + ) + + if is_master_node: + logger.log("Evaluation sampler details: ") + print("{}".format(eval_sampler)) + + return eval_loader + + +def create_train_val_loader(opts): + train_dataset, valid_dataset = train_val_datasets(opts) + + n_train_samples = len(train_dataset) + is_master_node = is_master(opts) + + train_sampler = build_sampler( + opts=opts, n_data_samples=n_train_samples, is_training=True + ) + if valid_dataset is not None: + n_valid_samples = len(valid_dataset) + valid_sampler = build_sampler( + opts=opts, n_data_samples=n_valid_samples, is_training=False + ) + else: + valid_sampler = None + + data_workers = getattr(opts, "dataset.workers", 1) + persistent_workers = getattr(opts, "dataset.persistent_workers", False) and ( + data_workers > 0 + ) + pin_memory = getattr(opts, "dataset.pin_memory", False) + prefetch_factor = getattr(opts, "dataset.prefetch_factor", 2) + + collate_fn_train, collate_fn_val = build_collate_fn(opts=opts) + + train_loader = affnetDataLoader( + dataset=train_dataset, + batch_size=1, # Handled inside data sampler + num_workers=data_workers, + pin_memory=pin_memory, + batch_sampler=train_sampler, + persistent_workers=persistent_workers, + collate_fn=partial(collate_fn_train, opts=opts) + if collate_fn_train is not None + else None, + prefetch_factor=prefetch_factor, + ) + + if valid_dataset is not None: + val_loader = affnetDataLoader( + dataset=valid_dataset, + batch_size=1, + batch_sampler=valid_sampler, + num_workers=data_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + collate_fn=partial(collate_fn_val, opts=opts) + if collate_fn_val is not None + else None, + ) + else: + val_loader = None + + if is_master_node: + logger.log("Training sampler details: ") + print("{}".format(train_sampler)) + + if valid_dataset is not None: + logger.log("Validation sampler details: ") + print("{}".format(valid_sampler)) + logger.log("Number of data workers: {}".format(data_workers)) + + return train_loader, val_loader, train_sampler diff --git a/Adaptive Frequency Filters/data/datasets/__init__.py b/Adaptive Frequency Filters/data/datasets/__init__.py new file mode 100644 index 0000000..a45b4ef --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/__init__.py @@ -0,0 +1,292 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import importlib +import argparse +import glob + +from utils.ddp_utils import is_master +from utils import logger + +from .dataset_base import BaseImageDataset +from .multi_modal_img_text import arguments_multi_modal_img_text + + +SUPPORTED_TASKS = [] +DATASET_REGISTRY = {} + +SEPARATOR = ":" + + +def register_dataset(name, task): + def register_dataset_class(cls): + if name in DATASET_REGISTRY: + raise ValueError( + "Cannot register duplicate dataset class ({})".format(name) + ) + + if not issubclass(cls, BaseImageDataset): + raise ValueError( + "Dataset ({}: {}) must extend BaseImageDataset".format( + name, cls.__name__ + ) + ) + + DATASET_REGISTRY[name + SEPARATOR + task] = cls + return cls + + return register_dataset_class + + +def supported_dataset_str(dataset_name, dataset_category): + supp_list = list(DATASET_REGISTRY.keys()) + supp_str = "Dataset ({}) under task ({}) is not yet supported. \n Supported datasets are:".format( + dataset_name, dataset_category + ) + for t_name in SUPPORTED_TASKS: + supp_str += "\n\t {}: ".format(logger.color_text(t_name)) + for i, m_name in enumerate(supp_list): + d_name, t_name1 = m_name.split(SEPARATOR) + if t_name == t_name1: + supp_str += "\n\t\t{}".format(d_name) + logger.error(supp_str + "\n") + + +def evaluation_datasets(opts): + dataset_name = getattr(opts, "dataset.name", "imagenet") + dataset_category = getattr(opts, "dataset.category", "classification") + + is_master_node = is_master(opts) + + name_dataset_task = dataset_name + SEPARATOR + dataset_category + eval_dataset = None + if name_dataset_task in DATASET_REGISTRY: + eval_dataset = DATASET_REGISTRY[name_dataset_task]( + opts=opts, is_training=False, is_evaluation=True + ) + else: + supported_dataset_str( + dataset_name=dataset_name, dataset_category=dataset_category + ) + + if is_master_node: + logger.log("Evaluation dataset details: ") + print("{}".format(eval_dataset)) + + return eval_dataset + + +def train_val_datasets(opts): + dataset_name = getattr(opts, "dataset.name", "imagenet") + dataset_category = getattr(opts, "dataset.category", "classification") + disable_val = getattr(opts, "dataset.disable_val", False) + + is_master_node = is_master(opts) + + name_dataset_task = dataset_name + SEPARATOR + dataset_category + train_dataset = valid_dataset = None + if name_dataset_task in DATASET_REGISTRY and not disable_val: + train_dataset = DATASET_REGISTRY[name_dataset_task](opts=opts, is_training=True) + valid_dataset = DATASET_REGISTRY[name_dataset_task]( + opts=opts, is_training=False + ) + elif name_dataset_task in DATASET_REGISTRY and disable_val: + train_dataset = DATASET_REGISTRY[name_dataset_task](opts=opts, is_training=True) + valid_dataset = None + else: + supported_dataset_str( + dataset_name=dataset_name, dataset_category=dataset_category + ) + + if is_master_node: + logger.log("Training and validation dataset details: ") + print("{}".format(train_dataset)) + print("{}".format(valid_dataset)) + return train_dataset, valid_dataset + + + +def general_dataset_args(parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Dataset", description="Arguments related to dataset" + ) + group.add_argument( + "--dataset.root-train", + type=str, + default="", + help="Root location of train dataset", + ) + group.add_argument( + "--dataset.root-val", + type=str, + default="", + help="Root location of valid dataset", + ) + group.add_argument( + "--dataset.root-test", + type=str, + default="", + help="Root location of test dataset", + ) + group.add_argument( + "--dataset.disable-val", action="store_true", help="Disable validation" + ) + + group.add_argument( + "--dataset.name", type=str, default="imagenet", help="Dataset name" + ) + group.add_argument( + "--dataset.category", + type=str, + default="classification", + help="Dataset category (e.g., segmentation, classification)", + ) + group.add_argument( + "--dataset.train-batch-size0", default=128, type=int, help="Training batch size" + ) + group.add_argument( + "--dataset.val-batch-size0", default=1, type=int, help="Validation batch size" + ) + group.add_argument( + "--dataset.eval-batch-size0", default=1, type=int, help="Validation batch size" + ) + group.add_argument( + "--dataset.workers", default=-1, type=int, help="Number of data workers" + ) + group.add_argument( + "--dataset.dali-workers", + default=-1, + type=int, + help="Number of data workers for dali", + ) + group.add_argument( + "--dataset.persistent-workers", + action="store_true", + help="Use same workers across all epochs in data loader", + ) + group.add_argument( + "--dataset.pin-memory", + action="store_true", + help="Use pin memory option in data loader", + ) + group.add_argument( + "--dataset.prefetch-factor", + type=int, + default=2, + help="Number of samples loaded in advance by each data worker", + ) + group.add_argument( + "--dataset.img-dtype", + type=str, + choices=["float", "half", "float16"], + default="float", + help="Image datatype", + ) + + group.add_argument( + "--dataset.cache-images-on-ram", action="store_true", help="Cache data on RAM" + ) + group.add_argument( + "--dataset.cache-limit", + type=float, + default=80.0, + help="Max. memory to use in RAM.", + ) + + # sample efficient training + group.add_argument( + "--dataset.sample-efficient-training.enable", + action="store_true", + help="sample efficient training", + ) + group.add_argument( + "--dataset.sample-efficient-training.sample-confidence", + type=float, + default=0.5, + help="Confidence for sample", + ) + group.add_argument( + "--dataset.sample-efficient-training.find-easy-samples-every-k-epochs", + type=int, + default=5, + help="Find easy samples after every K epochs", + ) + group.add_argument( + "--dataset.sample-efficient-training.min-sample-frequency", + type=int, + default=5, + help="Frequency that sample has been classified as easy for N number of times.", + ) + + group.add_argument( + "--dataset.decode-data-on-gpu", action="store_true", help="Decode data on GPU" + ) + group.add_argument( + "--dataset.sampler-type", + type=str, + default="batch", + help="Batch sampler or not.", + ) + + group.add_argument( + "--dataset.padding-index", + type=int, + default=None, + help="Padding index for text vocabulary", + ) + + group.add_argument( + "--dataset.text-vocab-size", type=int, default=-1, help="Text vocabulary size" + ) + + return parser + + +def arguments_dataset(parser: argparse.ArgumentParser): + parser = general_dataset_args(parser=parser) + + try: + from internal.utils.server_utils import dataset_server_args + parser = dataset_server_args(parser) + except ImportError as e: + pass + + # add zero-shot arguments + parser = arguments_multi_modal_img_text(parser=parser) + + # add dataset specific arguments + for k, v in DATASET_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + + return parser + + +# automatically import the datasets +dataset_dir = os.path.dirname(__file__) + +# supported tasks (each folder in datasets is for a particular task) +for abs_dir_path in glob.glob("{}/*".format(dataset_dir)): + if os.path.isdir(abs_dir_path): + file_or_folder_name = os.path.basename(abs_dir_path).strip() + if not file_or_folder_name.startswith( + "_" + ) and not file_or_folder_name.startswith("."): + SUPPORTED_TASKS.append(file_or_folder_name) + +for task in SUPPORTED_TASKS: + task_path = os.path.join(dataset_dir, task) + for file in os.listdir(task_path): + path = os.path.join(task_path, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + dataset_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module( + "data.datasets." + task + "." + dataset_name + ) diff --git a/Adaptive Frequency Filters/data/datasets/classification/__init__.py b/Adaptive Frequency Filters/data/datasets/classification/__init__.py new file mode 100644 index 0000000..ce25842 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/classification/__init__.py @@ -0,0 +1,9 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +""" +Image Classification Datasets +""" diff --git a/Adaptive Frequency Filters/data/datasets/classification/imagenet.py b/Adaptive Frequency Filters/data/datasets/classification/imagenet.py new file mode 100644 index 0000000..4b93f5d --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/classification/imagenet.py @@ -0,0 +1,221 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torchvision.datasets import ImageFolder +from typing import Optional, Tuple, Dict, List, Union +import torch +import argparse + +from utils import logger + +from .. import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_pil as T +from ...collate_fns import register_collate_fn + + +@register_dataset(name="imagenet", task="classification") +class ImagenetDataset(BaseImageDataset, ImageFolder): + """ + ImageNet Classification Dataset that uses PIL for reading and augmenting images. The dataset structure should + follow the ImageFolder class in :class:`torchvision.datasets.imagenet` + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + + .. note:: + We recommend to use this dataset class over the imagenet_opencv.py file. + + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + BaseImageDataset.__init__( + self, opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + root = self.root + ImageFolder.__init__( + self, root=root, transform=None, target_transform=None, is_valid_file=None + ) + + self.n_classes = len(list(self.class_to_idx.keys())) + setattr(opts, "model.classification.n_classes", self.n_classes) + setattr(opts, "dataset.collate_fn_name_train", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_val", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", "imagenet_collate_fn") + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add dataset-specific arguments to the parser.""" + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--dataset.imagenet.crop-ratio", + type=float, + default=0.875, + help="Crop ratio", + ) + return parser + + def _training_transforms(self, size: Union[Tuple, int], *args, **kwargs): + """ + Training data augmentation methods. + Image --> RandomResizedCrop --> RandomHorizontalFlip --> Optional(AutoAugment or RandAugment) + --> Tensor --> Optional(RandomErasing) --> Optional(MixUp) --> Optional(CutMix) + + .. note:: + 1. AutoAugment, RandAugment and TrivialAugmentWide are mutually exclusive. + 2. Mixup and CutMix are applied on batches are implemented in trainer. + """ + aug_list = [ + T.RandomResizedCrop(opts=self.opts, size=size), + T.RandomHorizontalFlip(opts=self.opts), + ] + auto_augment = getattr( + self.opts, "image_augmentation.auto_augment.enable", False + ) + rand_augment = getattr( + self.opts, "image_augmentation.rand_augment.enable", False + ) + trivial_augment_wide = getattr( + self.opts, "image_augmentation.trivial_augment_wide.enable", False + ) + if bool(auto_augment) + bool(rand_augment) + bool(trivial_augment_wide) > 1: + logger.error( + "AutoAugment, RandAugment and TrivialAugmentWide are mutually exclusive. Use either of them, but not more than one" + ) + elif auto_augment: + aug_list.append(T.AutoAugment(opts=self.opts)) + elif rand_augment: + if getattr( + self.opts, "image_augmentation.rand_augment.use_timm_library", False + ): + aug_list.append(T.RandAugmentTimm(opts=self.opts)) + else: + aug_list.append(T.RandAugment(opts=self.opts)) + elif trivial_augment_wide: + aug_list.append(T.TrivialAugmentWide(opts=self.opts)) + + aug_list.append(T.ToTensor(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.random_erase.enable", False): + aug_list.append(T.RandomErasing(opts=self.opts)) + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: Union[Tuple, int], *args, **kwargs): + """ + Validation augmentation + Image --> Resize --> CenterCrop --> ToTensor + """ + aug_list = [ + T.Resize(opts=self.opts), + T.CenterCrop(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _evaluation_transforms(self, size: Union[Tuple, int], *args, **kwargs): + """Same as the validation_transforms""" + return self._validation_transforms(size=size) + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + """ + :param batch_indexes_tup: Tuple of the form (Crop_size_W, Crop_size_H, Image_ID) + :return: dictionary containing input image, label, and sample_id. + """ + crop_size_h, crop_size_w, img_index = batch_indexes_tup + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + else: + # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + img_path, target = self.samples[img_index] + + input_img = self.read_image_pil(img_path) + + if input_img is None: + # Sometimes images are corrupt + # Skip such images + logger.log("Img index {} is possibly corrupt.".format(img_index)) + input_tensor = torch.zeros( + size=(3, crop_size_h, crop_size_w), dtype=self.img_dtype + ) + target = -1 + data = {"image": input_tensor} + else: + data = {"image": input_img} + data = transform_fn(data) + + data["samples"] = data.pop("image") + data["targets"] = target + data["sample_id"] = img_index + + return data + + def __len__(self) -> int: + return len(self.samples) + + def __repr__(self) -> str: + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\tn_classes={}\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.samples), + self.n_classes, + transforms_str, + ) + + +@register_collate_fn(name="imagenet_collate_fn") +def imagenet_collate_fn(batch: List, opts) -> Dict: + batch_size = len(batch) + img_size = [batch_size, *batch[0]["samples"].shape] + img_dtype = batch[0]["samples"].dtype + + images = torch.zeros(size=img_size, dtype=img_dtype) + # fill with -1, so that we can ignore corrupted images + labels = torch.full(size=[batch_size], fill_value=-1, dtype=torch.long) + sample_ids = torch.zeros(size=[batch_size], dtype=torch.long) + valid_indexes = [] + for i, batch_i in enumerate(batch): + label_i = batch_i.pop("targets") + images[i] = batch_i.pop("samples") + labels[i] = label_i # label is an int + sample_ids[i] = batch_i.pop("sample_id") # sample id is an int + if label_i != -1: + valid_indexes.append(i) + + valid_indexes = torch.tensor(valid_indexes, dtype=torch.long) + images = torch.index_select(images, dim=0, index=valid_indexes) + labels = torch.index_select(labels, dim=0, index=valid_indexes) + sample_ids = torch.index_select(sample_ids, dim=0, index=valid_indexes) + + channels_last = getattr(opts, "common.channels_last", False) + if channels_last: + images = images.to(memory_format=torch.channels_last) + + return {"samples": images, "targets": labels, "sample_id": sample_ids} diff --git a/Adaptive Frequency Filters/data/datasets/classification/imagenet_fast.py b/Adaptive Frequency Filters/data/datasets/classification/imagenet_fast.py new file mode 100644 index 0000000..1467dcb --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/classification/imagenet_fast.py @@ -0,0 +1,198 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +# from torchvision.datasets import ImageFolder +from utils.my_dataset_folder import ImageFolder +import os +from typing import Optional, Tuple, Dict, List, Union +import torch +import argparse + +from utils import logger + +from .. import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_pil as T +from ...collate_fns import register_collate_fn + + +@register_dataset(name="imagenet_fast", task="classification") +class ImagenetDataset(BaseImageDataset, ImageFolder): + """ + ImageNet Classification Dataset that uses PIL for reading and augmenting images. The dataset structure should + follow the ImageFolder class in :class:`torchvision.datasets.imagenet` + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + + .. note:: + We recommend to use this dataset class over the imagenet_opencv.py file. + + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + BaseImageDataset.__init__( + self, opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + root = self.root + # ImageFolder.__init__( + # self, root=root, transform=None, target_transform=None, is_valid_file=None + # ) + # assert is_training ^ is_evaluation + prefix = 'train' if is_training else 'val' + map_txt = os.path.join(root, '..', f"{prefix}_map.txt") + ImageFolder.__init__( + self, root=root, transform=None, target_transform=None, is_valid_file=None, map_txt=map_txt + ) + # self.n_classes = len(list(self.class_to_idx.keys())) + self.n_classes = len(self.classes) + setattr(opts, "model.classification.n_classes", self.n_classes) + setattr(opts, "dataset.collate_fn_name_train", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_val", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", "imagenet_collate_fn") + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add dataset-specific arguments to the parser.""" + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--dataset.imagenet_fast.crop-ratio", # --dataset.imagenet.crop-ratio + type=float, + default=0.875, + help="Crop ratio", + ) + return parser + + def _training_transforms(self, size: Union[Tuple, int], *args, **kwargs): + """ + Training data augmentation methods. + Image --> RandomResizedCrop --> RandomHorizontalFlip --> Optional(AutoAugment or RandAugment) + --> Tensor --> Optional(RandomErasing) --> Optional(MixUp) --> Optional(CutMix) + + .. note:: + 1. AutoAugment, RandAugment and TrivialAugmentWide are mutually exclusive. + 2. Mixup and CutMix are applied on batches are implemented in trainer. + """ + aug_list = [ + T.RandomResizedCrop(opts=self.opts, size=size), + T.RandomHorizontalFlip(opts=self.opts), + ] + auto_augment = getattr( + self.opts, "image_augmentation.auto_augment.enable", False + ) + rand_augment = getattr( + self.opts, "image_augmentation.rand_augment.enable", False + ) + trivial_augment_wide = getattr( + self.opts, "image_augmentation.trivial_augment_wide.enable", False + ) + if bool(auto_augment) + bool(rand_augment) + bool(trivial_augment_wide) > 1: + logger.error( + "AutoAugment, RandAugment and TrivialAugmentWide are mutually exclusive. Use either of them, but not more than one" + ) + elif auto_augment: + aug_list.append(T.AutoAugment(opts=self.opts)) + elif rand_augment: + if getattr( + self.opts, "image_augmentation.rand_augment.use_timm_library", False + ): + aug_list.append(T.RandAugmentTimm(opts=self.opts)) + else: + aug_list.append(T.RandAugment(opts=self.opts)) + elif trivial_augment_wide: + aug_list.append(T.TrivialAugmentWide(opts=self.opts)) + + aug_list.append(T.ToTensor(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.random_erase.enable", False): + aug_list.append(T.RandomErasing(opts=self.opts)) + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: Union[Tuple, int], *args, **kwargs): + """ + Validation augmentation + Image --> Resize --> CenterCrop --> ToTensor + """ + aug_list = [ + T.Resize(opts=self.opts), + T.CenterCrop(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _evaluation_transforms(self, size: Union[Tuple, int], *args, **kwargs): + """Same as the validation_transforms""" + return self._validation_transforms(size=size) + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + """ + :param batch_indexes_tup: Tuple of the form (Crop_size_W, Crop_size_H, Image_ID) + :return: dictionary containing input image, label, and sample_id. + """ + crop_size_h, crop_size_w, img_index = batch_indexes_tup + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + else: + # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + img_path, target = self.samples[img_index] + + input_img = self.read_image_pil(img_path) + + if input_img is None: + # Sometimes images are corrupt + # Skip such images + logger.log("Img index {} is possibly corrupt.".format(img_index)) + input_tensor = torch.zeros( + size=(3, crop_size_h, crop_size_w), dtype=self.img_dtype + ) + target = -1 + data = {"image": input_tensor} + else: + data = {"image": input_img} + data = transform_fn(data) + + data["samples"] = data.pop("image") + data["targets"] = target + data["sample_id"] = img_index + + return data + + def __len__(self) -> int: + return len(self.samples) + + def __repr__(self) -> str: + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\tn_classes={}\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.samples), + self.n_classes, + transforms_str, + ) diff --git a/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv.py b/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv.py new file mode 100644 index 0000000..2ae6008 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv.py @@ -0,0 +1,162 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torchvision.datasets import ImageFolder +from typing import Optional, Tuple, Dict +import numpy as np +import math +import warnings + +from utils import logger + +from .. import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_opencv as tf + + +@register_dataset(name="imagenet_opencv", task="classification") +class ImagenetOpenCVDataset(BaseImageDataset, ImageFolder): + """ + ImageNet Classification Dataset that uses OpenCV for data augmentation. + + The dataset structure should follow the ImageFolder class in :class:`torchvision.datasets.imagenet` + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + + .. note:: + This class is depreciated and will be removed in future versions (Use it for MobileViT evaluation). + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + warnings.warn( + "The use of dataset.name=imagenet_opencv is depreciated. Please use dataset.name=imagenet", + DeprecationWarning, + ) + BaseImageDataset.__init__( + self, opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + root = self.root + ImageFolder.__init__( + self, root=root, transform=None, target_transform=None, is_valid_file=None + ) + + self.n_classes = len(list(self.class_to_idx.keys())) + setattr(opts, "model.classification.n_classes", self.n_classes) + + setattr(opts, "dataset.collate_fn_name_train", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_val", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", "imagenet_collate_fn") + + def _training_transforms(self, size: tuple or int): + """ + Training data augmentation methods (RandomResizedCrop --> RandomHorizontalFlip --> ToTensor). + """ + aug_list = [ + tf.RandomResizedCrop(opts=self.opts, size=size), + tf.RandomHorizontalFlip(opts=self.opts), + tf.NumpyToTensor(opts=self.opts), + ] + return tf.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: tuple): + """Implements validation transformation method (Resize --> CenterCrop --> ToTensor).""" + if isinstance(size, (tuple, list)): + size = min(size) + + assert isinstance(size, int) + # (256 - 224) = 32 + # where 224/0.875 = 256 + + crop_ratio = getattr(self.opts, "dataset.imagenet.crop_ratio", 0.875) + if 0 < crop_ratio < 1.0: + scale_size = int(math.ceil(size / crop_ratio)) + scale_size = (scale_size // 32) * 32 + else: + logger.warning( + "Crop ratio should be between 0 and 1. Got: {}".format(crop_ratio) + ) + logger.warning("Setting scale_size as size + 32") + scale_size = size + 32 # int(make_divisible(crop_size / 0.875, divisor=32)) + + return tf.Compose( + opts=self.opts, + img_transforms=[ + tf.Resize(opts=self.opts, size=scale_size), + tf.CenterCrop(opts=self.opts, size=size), + tf.NumpyToTensor(opts=self.opts), + ], + ) + + def _evaluation_transforms(self, size: tuple): + """Same as the validation_transforms""" + return self._validation_transforms(size=size) + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + """ + + :param batch_indexes_tup: Tuple of the form (Crop_size_W, Crop_size_H, Image_ID) + :return: dictionary containing input image and label ID. + """ + crop_size_h, crop_size_w, img_index = batch_indexes_tup + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + else: # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + img_path, target = self.samples[img_index] + input_img = self.read_image_opencv(img_path) + + if input_img is None: + # Sometimes images are corrupt and cv2 is not able to load them + # Skip such images + logger.log( + "Img index {} is possibly corrupt. Removing it from the sample list".format( + img_index + ) + ) + del self.samples[img_index] + input_img = np.zeros(shape=(crop_size_h, crop_size_w, 3), dtype=np.uint8) + + data = {"image": input_img} + data = transform_fn(data) + + data["samples"] = data.pop("image") + data["targets"] = target + data["sample_id"] = img_index + + return data + + def __len__(self): + return len(self.samples) + + def __repr__(self): + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\tn_classes={}\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.samples), + self.n_classes, + transforms_str, + ) diff --git a/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv_bitplane_fast.py b/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv_bitplane_fast.py new file mode 100644 index 0000000..1b75a36 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv_bitplane_fast.py @@ -0,0 +1,159 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- +import os + +from utils.my_dataset_folder import ImageFolder +from typing import Optional, Tuple, Dict +import numpy as np +import math + +from utils import logger + +from .. import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_opencv as tf +from ...transforms.image_opencv import BitPlane + +# change name +@register_dataset(name="imagenet_opencv_bitplane_fast", task="classification") +class ImagenetOpenCVDataset(BaseImageDataset, ImageFolder): + """ + ImageNet Classification Dataset that uses OpenCV for data augmentation. + + The dataset structure should follow the ImageFolder class in :class:`torchvision.datasets.imagenet` + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + + .. note:: + This class is depreciated and will be removed in future versions (Use it for MobileViT evaluation). + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + BaseImageDataset.__init__( + self, opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + root = self.root + # assert is_training ^ is_evaluation + prefix = 'train' if is_training else 'val' + map_txt = os.path.join(root, '..', f"{prefix}_map.txt") + ImageFolder.__init__( + self, root=root, transform=None, target_transform=None, is_valid_file=None, map_txt=map_txt + ) + + self.n_classes = len(self.classes) + setattr(opts, "model.classification.n_classes", self.n_classes) + + def _training_transforms(self, size: tuple or int): + """ + Training data augmentation methods (RandomResizedCrop --> RandomHorizontalFlip --> ToTensor). + """ + aug_list = [ + tf.RandomResizedCrop(opts=self.opts, size=size), + tf.RandomHorizontalFlip(opts=self.opts), + BitPlane(opts=self.opts, h=size[0], w=size[1]), + tf.NumpyToTensor(opts=self.opts), + ] + return tf.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: tuple): + """Implements validation transformation method (Resize --> CenterCrop --> ToTensor).""" + if isinstance(size, (tuple, list)): + size = min(size) + + assert isinstance(size, int) + # (256 - 224) = 32 + # where 224/0.875 = 256 + + crop_ratio = getattr(self.opts, "dataset.imagenet.crop_ratio", 0.875) + if 0 < crop_ratio < 1.0: + scale_size = int(math.ceil(size / crop_ratio)) + scale_size = (scale_size // 32) * 32 + else: + logger.warning( + "Crop ratio should be between 0 and 1. Got: {}".format(crop_ratio) + ) + logger.warning("Setting scale_size as size + 32") + scale_size = size + 32 # int(make_divisible(crop_size / 0.875, divisor=32)) + + return tf.Compose( + opts=self.opts, + img_transforms=[ + tf.Resize(opts=self.opts, size=scale_size), + tf.CenterCrop(opts=self.opts, size=size), + BitPlane(opts=self.opts, h=size, w=size), + tf.NumpyToTensor(opts=self.opts), + ], + ) + + def _evaluation_transforms(self, size: tuple): + """Same as the validation_transforms""" + return self._validation_transforms(size=size) + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + """ + + :param batch_indexes_tup: Tuple of the form (Crop_size_W, Crop_size_H, Image_ID) + :return: dictionary containing input image and label ID. + """ + crop_size_h, crop_size_w, img_index = batch_indexes_tup + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + else: # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + img_path, target = self.samples[img_index] + input_img = self.read_image_opencv(img_path) + + if input_img is None: + # Sometimes images are corrupt and cv2 is not able to load them + # Skip such images + logger.log( + "Img index {} is possibly corrupt. Removing it from the sample list".format( + img_index + ) + ) + del self.samples[img_index] + input_img = np.zeros(shape=(crop_size_h, crop_size_w, 3), dtype=np.uint8) + + data = {"image": input_img} + data = transform_fn(data) + + data["label"] = target + data["sample_id"] = img_index + + return data + + def __len__(self): + return len(self.samples) + + def __repr__(self): + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\tn_classes={}\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.samples), + self.n_classes, + transforms_str, + ) diff --git a/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv_fast.py b/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv_fast.py new file mode 100644 index 0000000..211e230 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/classification/imagenet_opencv_fast.py @@ -0,0 +1,161 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- +import os + +from utils.my_dataset_folder import ImageFolder +from typing import Optional, Tuple, Dict +import numpy as np +import math + +from utils import logger + +from .. import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_opencv as tf + + +@register_dataset(name="imagenet_opencv_fast", task="classification") +class ImagenetOpenCVDataset(BaseImageDataset, ImageFolder): + """ + ImageNet Classification Dataset that uses OpenCV for data augmentation. + + The dataset structure should follow the ImageFolder class in :class:`torchvision.datasets.imagenet` + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + + .. note:: + This class is depreciated and will be removed in future versions (Use it for MobileViT evaluation). + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + BaseImageDataset.__init__( + self, opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + root = self.root + # assert is_training ^ is_evaluation + prefix = 'train' if is_training else 'val' + map_txt = os.path.join(root, '..', f"{prefix}_map.txt") + ImageFolder.__init__( + self, root=root, transform=None, target_transform=None, is_valid_file=None, map_txt=map_txt + ) + + self.n_classes = len(self.classes) + setattr(opts, "model.classification.n_classes", self.n_classes) + + setattr(opts, "dataset.collate_fn_name_train", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_val", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", "imagenet_collate_fn") + + def _training_transforms(self, size: tuple or int): + """ + Training data augmentation methods (RandomResizedCrop --> RandomHorizontalFlip --> ToTensor). + """ + aug_list = [ + tf.RandomResizedCrop(opts=self.opts, size=size), + tf.RandomHorizontalFlip(opts=self.opts), + tf.NumpyToTensor(opts=self.opts), + ] + return tf.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: tuple): + """Implements validation transformation method (Resize --> CenterCrop --> ToTensor).""" + if isinstance(size, (tuple, list)): + size = min(size) + + assert isinstance(size, int) + # (256 - 224) = 32 + # where 224/0.875 = 256 + + crop_ratio = getattr(self.opts, "dataset.imagenet.crop_ratio", 0.875) + if 0 < crop_ratio < 1.0: + scale_size = int(math.ceil(size / crop_ratio)) + scale_size = (scale_size // 32) * 32 + else: + logger.warning( + "Crop ratio should be between 0 and 1. Got: {}".format(crop_ratio) + ) + logger.warning("Setting scale_size as size + 32") + scale_size = size + 32 # int(make_divisible(crop_size / 0.875, divisor=32)) + + return tf.Compose( + opts=self.opts, + img_transforms=[ + tf.Resize(opts=self.opts, size=scale_size), + tf.CenterCrop(opts=self.opts, size=size), + tf.NumpyToTensor(opts=self.opts), + ], + ) + + def _evaluation_transforms(self, size: tuple): + """Same as the validation_transforms""" + return self._validation_transforms(size=size) + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + """ + + :param batch_indexes_tup: Tuple of the form (Crop_size_W, Crop_size_H, Image_ID) + :return: dictionary containing input image and label ID. + """ + crop_size_h, crop_size_w, img_index = batch_indexes_tup + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + else: # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + img_path, target = self.samples[img_index] + input_img = self.read_image_opencv(img_path) + + if input_img is None: + # Sometimes images are corrupt and cv2 is not able to load them + # Skip such images + logger.log( + "Img index {} is possibly corrupt. Removing it from the sample list".format( + img_index + ) + ) + del self.samples[img_index] + input_img = np.zeros(shape=(crop_size_h, crop_size_w, 3), dtype=np.uint8) + + data = {"image": input_img} + data = transform_fn(data) + + data["samples"] = data.pop("image") + data["targets"] = target + data["sample_id"] = img_index + + return data + + def __len__(self): + return len(self.samples) + + def __repr__(self): + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\tn_classes={}\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.samples), + self.n_classes, + transforms_str, + ) diff --git a/Adaptive Frequency Filters/data/datasets/classification/imagenet_v2.py b/Adaptive Frequency Filters/data/datasets/classification/imagenet_v2.py new file mode 100644 index 0000000..a398f42 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/classification/imagenet_v2.py @@ -0,0 +1,174 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse +import tarfile +from pathlib import Path +from typing import Optional, Tuple, Dict, Union + +import torch + +from utils import logger +from utils.download_utils import get_local_path + +from .. import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_pil as T + +IMAGENETv2_SPLIT_LINK_MAP = { + "matched_frequency": { + "url": "https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-matched-frequency.tar.gz", + "extracted_folder_name": "imagenetv2-matched-frequency-format-val", + }, + "threshold_0.7": { + "url": "https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-threshold0.7.tar.gz", + "extracted_folder_name": "imagenetv2-threshold0.7-format-val", + }, + "top_images": { + "url": "https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-top-images.tar.gz", + "extracted_folder_name": "imagenetv2-top-images-format-val", + }, +} + + +@register_dataset(name="imagenet_v2", task="classification") +class Imagenetv2Dataset(BaseImageDataset): + """ + `ImageNetv2 Dataset `_ for studying the robustness of models trained on ImageNet dataset + + Args: + opts: command-line arguments + is_training (Optional[bool]): ImageNetv2 should be used for evaluation only Default: False + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: True + + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = False, + is_evaluation: Optional[bool] = True, + *args, + **kwargs, + ) -> None: + if is_training: + logger.error( + "{} can only be used for evaluation".format(self.__class__.__name__) + ) + + super().__init__( + opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + + split = getattr(opts, "dataset.imagenet_v2.split", None) + if split is None or split not in IMAGENETv2_SPLIT_LINK_MAP.keys(): + logger.error( + "Please specify split for ImageNetv2. Supported ImageNetv2 splits are: {}".format( + IMAGENETv2_SPLIT_LINK_MAP.keys() + ) + ) + + split_path = get_local_path(opts, path=IMAGENETv2_SPLIT_LINK_MAP[split]["url"]) + with tarfile.open(split_path) as tf: + tf.extractall(self.root) + + root = Path( + "{}/{}".format( + self.root, IMAGENETv2_SPLIT_LINK_MAP[split]["extracted_folder_name"] + ) + ) + file_names = list(root.glob("**/*.jpeg")) + self.file_names = file_names + + setattr(opts, "dataset.collate_fn_name_train", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_val", "imagenet_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", "imagenet_collate_fn") + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add dataset-specific arguments to the parser.""" + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--dataset.imagenet-v2.split", + type=str, + default="matched-frequency", + help="ImageNetv2 dataset. Possible choices are: {}".format( + [ + f"{i + 1}: {split_name}" + for i, split_name in enumerate(IMAGENETv2_SPLIT_LINK_MAP.keys()) + ] + ), + choices=IMAGENETv2_SPLIT_LINK_MAP.keys(), + ) + return parser + + def _validation_transforms(self, size: Union[Tuple, int], *args, **kwargs): + """ + Validation augmentation + Image --> Resize --> CenterCrop --> ToTensor + """ + aug_list = [ + T.Resize(opts=self.opts), + T.CenterCrop(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + """ + :param batch_indexes_tup: Tuple of the form (Crop_size_W, Crop_size_H, Image_ID) + :return: dictionary containing input image, label, and sample_id. + """ + crop_size_h, crop_size_w, img_index = batch_indexes_tup + + # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + # infer target label from the file name + # file names are organized as SPLIT_NAME-format-val/class_idx/*.jpg + # Example: All images in this folder (imagenetv2-matched-frequency-format-val/0/*.jpg) belong to class 0 + img_path = str(self.file_names[img_index]) + target = int(self.file_names[img_index].parent.name) + + input_img = self.read_image_pil(img_path) + if input_img is None: + # Sometimes images are corrupt + # Skip such images + logger.log("Img index {} is possibly corrupt.".format(img_index)) + input_tensor = torch.zeros( + size=(3, crop_size_h, crop_size_w), dtype=self.img_dtype + ) + target = -1 + data = {"image": input_tensor} + else: + data = {"image": input_img} + data = transform_fn(data) + + data["samples"] = data["image"] + data["targets"] = target + data["sample_id"] = img_index + + return data + + def __len__(self) -> int: + return len(self.file_names) + + def __repr__(self) -> str: + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tsamples={}\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + len(self.file_names), + transforms_str, + ) diff --git a/Adaptive Frequency Filters/data/datasets/dataset_base.py b/Adaptive Frequency Filters/data/datasets/dataset_base.py new file mode 100644 index 0000000..c7a720a --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/dataset_base.py @@ -0,0 +1,231 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import copy +import warnings +import torch +from torch import Tensor +from torch.utils import data +import cv2 +from PIL import Image +from typing import Optional, Union, Dict +import argparse +import psutil +import time +import numpy as np +from torchvision.io import ( + read_image, + read_file, + decode_jpeg, + ImageReadMode, + decode_image, +) +import io + +from utils import logger +from utils.ddp_utils import is_start_rank_node, is_master + + +class BaseImageDataset(data.Dataset): + """ + Base Dataset class for Image datasets + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ): + if getattr(opts, "dataset.trove.enable", False): + opts = self.load_from_server(opts=opts, is_training=is_training) + + root = ( + getattr(opts, "dataset.root_train", None) + if is_training + else getattr(opts, "dataset.root_val", None) + ) + self.root = root + self.is_training = is_training + self.is_evaluation = is_evaluation + self.sampler_name = getattr(opts, "sampler.name", None) + self.opts = opts + + image_device_cuda = getattr(self.opts, "dataset.decode_data_on_gpu", False) + device = getattr(self.opts, "dev.device", torch.device("cpu")) + use_cuda = False + if image_device_cuda and ( + (isinstance(device, str) and device.find("cuda") > -1) + or (isinstance(device, torch.device) and device.type.find("cuda") > -1) + ): # cuda could be cuda:0 + use_cuda = True + + if use_cuda and getattr(opts, "dataset.pin_memory", False): + if is_master(opts): + logger.error( + "For loading images on GPU, --dataset.pin-memory should be disabled." + ) + + self.device = device if use_cuda else torch.device("cpu") + + self.cached_data = ( + dict() + if getattr(opts, "dataset.cache_images_on_ram", False) and is_training + else None + ) + if self.cached_data is not None: + if not getattr(opts, "dataset.persistent_workers", False): + if is_master(opts): + logger.error( + "For caching, --dataset.persistent-workers should be enabled." + ) + + self.cache_limit = getattr(opts, "dataset.cache_limit", 80.0) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + @staticmethod + def load_from_server(opts, is_training): + try: + from internal.utils.server_utils import load_from_data_server + + opts = load_from_data_server(opts=opts, is_training=is_training) + except ImportError as e: + import traceback + traceback.print_exc() + logger.error( + "Unable to load data. Please load data manually. Error: {}".format(e) + ) + + return opts + + def _training_transforms(self, *args, **kwargs): + raise NotImplementedError + + def _validation_transforms(self, *args, **kwargs): + raise NotImplementedError + + def _evaluation_transforms(self, *args, **kwargs): + raise NotImplementedError + + def read_image_pil(self, path: str, *args, **kwargs): + def convert_to_rgb(inp_data: Union[str, io.BytesIO]): + try: + rgb_img = Image.open(inp_data).convert("RGB") + except: + rgb_img = None + return rgb_img + + if self.cached_data is not None: + # code for caching data on RAM + used_memory = float(psutil.virtual_memory().percent) + + if path in self.cached_data: + img_byte = self.cached_data[path] + + elif (path not in self.cached_data) and (used_memory <= self.cache_limit): + # image is not present in cache and RAM usage is less than the threshold, add to cache + with open(path, "rb") as bin_file: + bin_file_data = bin_file.read() + img_byte = io.BytesIO(bin_file_data) + self.cached_data[path] = img_byte + else: + with open(path, "rb") as bin_file: + bin_file_data = bin_file.read() + img_byte = io.BytesIO(bin_file_data) # in-memory data + img = convert_to_rgb(img_byte) + else: + img = convert_to_rgb(path) + return img + + def read_pil_image_torchvision(self, path: str): + if self.cached_data is not None: + # code for caching data on RAM + used_memory = float(psutil.virtual_memory().percent) + + if path in self.cached_data: + byte_img = self.cached_data[path] + elif (path not in self.cached_data) and (used_memory <= self.cache_limit): + # image is not present in cache and RAM usage is less than the threshold, add to cache + byte_img = read_file(path) + self.cached_data[path] = byte_img + else: + byte_img = read_file(path) + else: + byte_img = read_file(path) + img = decode_image(byte_img, mode=ImageReadMode.RGB) + return img + + def read_image_tensor(self, path: str): + if self.cached_data is not None: + # code for caching data on RAM + used_memory = float(psutil.virtual_memory().percent) + + if path in self.cached_data: + byte_img = self.cached_data[path] + elif (path not in self.cached_data) and (used_memory <= self.cache_limit): + # image is not present in cache and RAM usage is less than the threshold, add to cache + byte_img = read_file(path) + self.cached_data[path] = byte_img + else: + byte_img = read_file(path) + else: + byte_img = read_file(path) + img = decode_jpeg(byte_img, device=self.device, mode=ImageReadMode.RGB) + return img + + @staticmethod + def read_mask_pil(path: str): + try: + mask = Image.open(path) + if mask.mode != "L": + logger.error("Mask mode should be L. Got: {}".format(mask.mode)) + return mask + except: + return None + + @staticmethod + def read_image_opencv(path: str): + warnings.warn( + "The use of read_image_opencv function is depreciated. Please use read_image_pil", + DeprecationWarning, + ) + return cv2.imread( + path, cv2.IMREAD_COLOR + ) # Image is read in BGR Format and not RGB format + + @staticmethod + def read_mask_opencv(path: str): + warnings.warn( + "The use of read_mask_opencv function is depreciated. Please use read_mask_pil", + DeprecationWarning, + ) + return cv2.imread(path, cv2.IMREAD_GRAYSCALE) + + @staticmethod + def convert_mask_to_tensor(mask): + # convert to tensor + mask = np.array(mask) + if len(mask.shape) > 2 and mask.shape[-1] > 1: + mask = np.ascontiguousarray(mask.transpose(2, 0, 1)) + return torch.as_tensor(mask, dtype=torch.long) + + @staticmethod + def adjust_mask_value(): + return 0 + + @staticmethod + def class_names(): + pass + + def __repr__(self): + return "{}(\n\troot={}\n\t is_training={})".format( + self.__class__.__name__, self.root, self.is_training + ) diff --git a/Adaptive Frequency Filters/data/datasets/detection/__init__.py b/Adaptive Frequency Filters/data/datasets/detection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/data/datasets/detection/coco_base.py b/Adaptive Frequency Filters/data/datasets/detection/coco_base.py new file mode 100644 index 0000000..3b1fbf5 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/detection/coco_base.py @@ -0,0 +1,343 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from pycocotools.coco import COCO +from pycocotools import mask as coco_mask +import os +from typing import Optional, Tuple, Dict, List +import numpy as np +import argparse + +from utils import logger + +from ...transforms import image_pil as T +from ...datasets import BaseImageDataset, register_dataset + + +@register_dataset(name="coco", task="detection") +class COCODetection(BaseImageDataset): + """ + Base class for the MS COCO Object Detection Dataset. + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + + .. note:: + This class implements basic functions (e.g., reading image and annotations), and does not implement + training/validation transforms. Detector specific sub-classes should extend this class and implement those + methods. See `coco_ssd.py` as an example for SSD. + + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + + split = "train" if is_training else "val" + year = 2017 + ann_file = os.path.join( + self.root, "annotations/instances_{}{}.json".format(split, year) + ) + + # disable printing, so that pycocotools print statements are not printed + logger.disable_printing() + + self.coco = COCO(ann_file) + self.img_dir = os.path.join(self.root, "{}{}".format(split, year)) + self.ids = ( + list(self.coco.imgToAnns.keys()) + if is_training + else list(self.coco.imgs.keys()) + ) + + coco_categories = sorted(self.coco.getCatIds()) + bkrnd_id = ( + 0 if getattr(opts, "dataset.detection.no_background_id", False) else 1 + ) + self.coco_id_to_contiguous_id = { + coco_id: i + bkrnd_id for i, coco_id in enumerate(coco_categories) + } + self.contiguous_id_to_coco_id = { + v: k for k, v in self.coco_id_to_contiguous_id.items() + } + self.num_classes = len(self.contiguous_id_to_coco_id.keys()) + bkrnd_id + + # enable printing + logger.enable_printing() + + setattr(opts, "model.detection.n_classes", self.num_classes) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--dataset.detection.no-background-id", + action="store_true", + help="Do not include background id", + ) + return parser + + def _training_transforms(self, size: tuple, ignore_idx: Optional[int] = 255): + """Training transforms should be implemented in sub-class""" + raise NotImplementedError + + def _validation_transforms(self, size: tuple, *args, **kwargs): + """Validation transforms should be implemented in sub-class""" + raise NotImplementedError + + def _evaluation_transforms(self, size: tuple, *args, **kwargs): + """Evaluation or Inference transforms (Resize (Optional) --> Tensor). + + .. note:: + Resizing the input to the same resolution as the detector's input is not enabled by default. + It can be enabled by passing **--evaluation.detection.resize-input-images** flag. + + """ + aug_list = [] + if getattr(self.opts, "evaluation.detection.resize_input_images", False): + aug_list.append(T.Resize(opts=self.opts, img_size=size)) + + aug_list.append(T.ToTensor(opts=self.opts)) + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def __getitem__(self, batch_indexes_tup: Tuple, *args, **kwargs) -> Dict: + crop_size_h, crop_size_w, img_index = batch_indexes_tup + + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + elif self.is_evaluation: + transform_fn = self._evaluation_transforms(size=(crop_size_h, crop_size_w)) + else: # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + image_id = self.ids[img_index] + + image, img_name = self.get_image(image_id=image_id) + im_width, im_height = image.size + + boxes, labels, mask = self.get_boxes_and_labels( + image_id=image_id, + image_width=im_width, + image_height=im_height, + include_masks=True, + ) + + data = { + "image": image, + "box_labels": labels, + "box_coordinates": boxes, + "mask": mask, + } + + if transform_fn is not None: + data = transform_fn(data) + + output_data = { + "samples": { + "image": data["image"], + }, + "targets": { + "box_labels": data["box_labels"], + "box_coordinates": data["box_coordinates"], + "mask": data["mask"], + "image_id": torch.tensor(image_id), + "image_width": torch.tensor(im_width), + "image_height": torch.tensor(im_height), + }, + } + + return output_data + + def __len__(self): + return len(self.ids) + + def get_boxes_and_labels( + self, image_id, image_width, image_height, *args, include_masks=False, **kwargs + ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: + ann_ids = self.coco.getAnnIds(imgIds=image_id) + ann = self.coco.loadAnns(ann_ids) + + # filter crowd annotations + ann = [obj for obj in ann if obj["iscrowd"] == 0] + boxes = np.array( + [self._xywh2xyxy(obj["bbox"], image_width, image_height) for obj in ann], + np.float32, + ).reshape((-1, 4)) + labels = np.array( + [self.coco_id_to_contiguous_id[obj["category_id"]] for obj in ann], np.int64 + ).reshape((-1,)) + # remove invalid boxes + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + labels = labels[keep] + + masks = None + if include_masks: + masks = [] + for obj in ann: + rle = coco_mask.frPyObjects( + obj["segmentation"], image_height, image_width + ) + m = coco_mask.decode(rle) + if len(m.shape) < 3: + mask = m.astype(np.uint8) + else: + mask = (np.sum(m, axis=2) > 0).astype(np.uint8) + masks.append(mask) + + if len(masks) > 0: + masks = np.stack(masks, axis=0) + else: + masks = np.zeros(shape=(0, image_height, image_width), dtype=np.uint8) + masks = masks.astype(np.uint8) + masks = torch.from_numpy(masks) + masks = masks[keep] + assert len(boxes) == len(labels) == len(masks) + return boxes, labels, masks + else: + return boxes, labels, None + + def _xywh2xyxy(self, box, image_width, image_height) -> List: + x1, y1, w, h = box + return [ + max(0, x1), + max(0, y1), + min(x1 + w, image_width), + min(y1 + h, image_height), + ] + + def get_image(self, image_id: int) -> Tuple: + file_name = self.coco.loadImgs(image_id)[0]["file_name"] + image_file = os.path.join(self.img_dir, file_name) + image = self.read_image_pil(image_file) + return image, file_name + + def extra_repr(self) -> str: + return "" + + def __repr__(self) -> str: + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + elif self.is_evaluation: + transforms_str = self._evaluation_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + repr_str = ( + "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\ttransforms={}".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.ids), + transforms_str, + ) + ) + repr_str += self.extra_repr() + repr_str += "\n)" + return repr_str + + @staticmethod + def class_names() -> List: + return [ + "background", + "person", + "bicycle", + "car", + "motorcycle", + "airplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "backpack", + "umbrella", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "couch", + "potted plant", + "bed", + "dining table", + "toilet", + "tv", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", + ] diff --git a/Adaptive Frequency Filters/data/datasets/detection/coco_mask_rcnn.py b/Adaptive Frequency Filters/data/datasets/detection/coco_mask_rcnn.py new file mode 100644 index 0000000..bc0e7cd --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/detection/coco_mask_rcnn.py @@ -0,0 +1,151 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from typing import Optional, Tuple, Dict, List +import math +import argparse + +from .coco_base import COCODetection +from ...transforms import image_pil as T +from ...datasets import register_dataset +from ...collate_fns import register_collate_fn + + +@register_dataset(name="coco_mask_rcnn", task="detection") +class COCODetectionMaskRCNN(COCODetection): + """Dataset class for the MS COCO Object Detection using Mask RCNN . + + Args: + opts : + Command line arguments + is_training : bool + A flag used to indicate training or validation mode + is_evaluation : bool + A flag used to indicate evaluation (or inference) mode + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + + # set the collate functions for the dataset + setattr(opts, "dataset.collate_fn_name_train", "coco_mask_rcnn_collate_fn") + setattr(opts, "dataset.collate_fn_name_val", "coco_mask_rcnn_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", "coco_mask_rcnn_collate_fn") + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--dataset.detection.coco-mask-rcnn.use-lsj-aug", + action="store_true", + help="Use large scale jitter augmentation for training Mask RCNN model", + ) + + return parser + + def _training_transforms(self, size: tuple, ignore_idx: Optional[int] = 255): + """Training data augmentation methods + (Resize --> RandomHorizontalFlip --> ToTensor). + """ + + if getattr(self.opts, "dataset.detection.coco_mask_rcnn.use_lsj_aug", False): + aug_list = [ + T.ScaleJitter(opts=self.opts), + T.FixedSizeCrop(opts=self.opts), + T.RandomHorizontalFlip(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + else: + aug_list = [ + T.Resize(opts=self.opts, img_size=size), + T.RandomHorizontalFlip(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: tuple, *args, **kwargs): + """Implements validation transformation method (Resize --> ToTensor).""" + aug_list = [ + T.Resize(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def __getitem__(self, batch_indexes_tup: Tuple, *args, **kwargs) -> Dict: + crop_size_h, crop_size_w, img_index = batch_indexes_tup + + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + else: # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + image_id = self.ids[img_index] + + image, img_name = self.get_image(image_id=image_id) + im_width, im_height = image.size + + boxes, labels, mask = self.get_boxes_and_labels( + image_id=image_id, + image_width=im_width, + image_height=im_height, + include_masks=True, + ) + + data = { + "image": image, + "box_labels": labels, + "box_coordinates": boxes, + "mask": mask, + } + + if transform_fn is not None: + data = transform_fn(data) + + output_data = { + "samples": { + "image": data["image"], + # PyTorch Mask RCNN implementation expect labels as an input. Because we do not want to change the + # the training infrastructure of affnet library, we pass labels as part of image key and + # handle it in the model. + "label": { + "labels": data["box_labels"], + "boxes": data["box_coordinates"], + "masks": data["mask"], + }, + }, + "targets": { + "image_id": torch.tensor(image_id), + "image_width": torch.tensor(im_width), + "image_height": torch.tensor(im_height), + }, + } + + return output_data + + +@register_collate_fn(name="coco_mask_rcnn_collate_fn") +def coco_mask_rcnn_collate_fn(batch: List, opts, *args, **kwargs) -> Dict: + new_batch = {"samples": {"image": [], "label": []}, "targets": []} + + for b_id, batch_ in enumerate(batch): + new_batch["samples"]["image"].append(batch_["samples"]["image"]) + new_batch["samples"]["label"].append(batch_["samples"]["label"]) + new_batch["targets"].append(batch_["targets"]) + + return new_batch diff --git a/Adaptive Frequency Filters/data/datasets/detection/coco_ssd.py b/Adaptive Frequency Filters/data/datasets/detection/coco_ssd.py new file mode 100644 index 0000000..0086409 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/detection/coco_ssd.py @@ -0,0 +1,225 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from typing import Optional, Tuple, Dict +import math +import argparse + +from utils import logger +from affnet.matcher_det import build_matcher +from affnet.anchor_generator import build_anchor_generator + +from .coco_base import COCODetection +from ...transforms import image_pil as T +from ...datasets import register_dataset +from ...collate_fns import register_collate_fn + + +@register_dataset(name="coco_ssd", task="detection") +class COCODetectionSSD(COCODetection): + """Dataset class for the MS COCO Object Detection using Single Shot Object Detector (SSD). + + Args: + opts : + Command line arguments + is_training : bool + A flag used to indicate training or validation mode + is_evaluation : bool + A flag used to indicate evaluation (or inference) mode + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + + anchor_gen_name = getattr(opts, "anchor_generator.name", None) + if anchor_gen_name is None or anchor_gen_name != "ssd": + logger.error("For SSD, we need --anchor-generator.name to be ssd") + + self.anchor_box_generator = build_anchor_generator(opts=opts, is_numpy=True) + + self.output_strides = self.anchor_box_generator.output_strides + + if getattr(opts, "matcher.name") != "ssd": + logger.error("For SSD, we need --matcher.name as ssd") + + self.match_prior = build_matcher(opts=opts) + + # set the collate functions for the dataset + setattr(opts, "dataset.collate_fn_name_train", "coco_ssd_collate_fn") + setattr(opts, "dataset.collate_fn_name_val", "coco_ssd_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", "coco_ssd_collate_fn") + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + return parser + + def _training_transforms(self, size: tuple, ignore_idx: Optional[int] = 255): + """Training data augmentation methods + (SSDCroping --> PhotometricDistort --> RandomHorizontalFlip -> Resize --> ToTensor). + """ + aug_list = [ + T.SSDCroping(opts=self.opts), + T.PhotometricDistort(opts=self.opts), + T.RandomHorizontalFlip(opts=self.opts), + T.Resize(opts=self.opts, img_size=size), + T.BoxPercentCoords(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: tuple, *args, **kwargs): + """Implements validation transformation method (Resize --> ToTensor).""" + aug_list = [ + T.Resize(opts=self.opts), + T.BoxPercentCoords(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def generate_anchors(self, height, width): + """Generate anchors **on-the-fly** based on the input resolution.""" + anchors = [] + for output_stride in self.output_strides: + if output_stride == -1: + fm_width = fm_height = 1 + else: + fm_width = int(math.ceil(width / output_stride)) + fm_height = int(math.ceil(height / output_stride)) + fm_anchor = self.anchor_box_generator( + fm_height=fm_height, fm_width=fm_width, fm_output_stride=output_stride + ) + anchors.append(fm_anchor) + anchors = torch.cat(anchors, dim=0) + return anchors + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + crop_size_h, crop_size_w, img_index = batch_indexes_tup + + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + else: + # During evaluation, we use base class + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + image_id = self.ids[img_index] + + image, img_fname = self.get_image(image_id=image_id) + im_width, im_height = image.size + boxes, labels, _ = self.get_boxes_and_labels( + image_id=image_id, image_width=im_width, image_height=im_height + ) + + data = {"image": image, "box_labels": labels, "box_coordinates": boxes} + + data = transform_fn(data) + + # convert to priors + anchors = self.generate_anchors(height=crop_size_h, width=crop_size_w) + + gt_coordinates, gt_labels = self.match_prior( + gt_boxes=data["box_coordinates"], + gt_labels=data["box_labels"], + anchors=anchors, + ) + + output_data = { + "samples": {"image": data.pop("image")}, + "targets": { + "box_labels": gt_labels, + "box_coordinates": gt_coordinates, + "image_id": torch.tensor(image_id), + "image_width": torch.tensor(im_width), + "image_height": torch.tensor(im_height), + }, + } + + return output_data + + def __repr__(self): + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + elif self.is_evaluation: + transforms_str = self._evaluation_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\ttransforms={}\n\tmatcher={}\n\tanchor_gen={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.ids), + transforms_str, + self.match_prior, + self.anchor_box_generator, + ) + + +@register_collate_fn(name="coco_ssd_collate_fn") +def coco_ssd_collate_fn(batch, opts): + new_batch = { + "samples": {"image": []}, + "targets": { + "box_labels": [], + "box_coordinates": [], + "image_id": [], + "image_width": [], + "image_height": [], + }, + } + + for b_id, batch_ in enumerate(batch): + # prepare inputs + new_batch["samples"]["image"].append(batch_["samples"]["image"]) + + # prepare outputs + new_batch["targets"]["box_labels"].append(batch_["targets"]["box_labels"]) + new_batch["targets"]["box_coordinates"].append( + batch_["targets"]["box_coordinates"] + ) + new_batch["targets"]["image_id"].append(batch_["targets"]["image_id"]) + new_batch["targets"]["image_width"].append(batch_["targets"]["image_width"]) + new_batch["targets"]["image_height"].append(batch_["targets"]["image_height"]) + + # stack inputs + new_batch["samples"]["image"] = torch.stack(new_batch["samples"]["image"], dim=0) + + # stack outputs + new_batch["targets"]["box_labels"] = torch.stack( + new_batch["targets"]["box_labels"], dim=0 + ) + + new_batch["targets"]["box_coordinates"] = torch.stack( + new_batch["targets"]["box_coordinates"], dim=0 + ) + + new_batch["targets"]["image_id"] = torch.stack( + new_batch["targets"]["image_id"], dim=0 + ) + + new_batch["targets"]["image_width"] = torch.stack( + new_batch["targets"]["image_width"], dim=0 + ) + + new_batch["targets"]["image_height"] = torch.stack( + new_batch["targets"]["image_height"], dim=0 + ) + + return new_batch diff --git a/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/__init__.py b/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/__init__.py new file mode 100644 index 0000000..5ecce33 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/__init__.py @@ -0,0 +1,41 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse + + +def arguments_multi_modal_img_text( + parser: argparse.ArgumentParser, +) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="Multi-modal image-text arguments", + description="Multi-modal image-text arguments", + ) + + group.add_argument( + "--dataset.multi-modal-img-text.zero-shot-eval", + action="store_true", + help="Use zero shot evaluation", + ) + + group.add_argument( + "--dataset.multi-modal-img-text.context-length", + type=int, + default=77, + help="Context length for the text model", + ) + + group.add_argument( + "--dataset.multi-modal-img-text.trunc-seq-len", + action="store_true", + help="Enable sequence length truncation", + ) + + from .zero_shot import arguments_zero_shot_dataset + + parser = arguments_zero_shot_dataset(parser) + + return parser diff --git a/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/base_multi_modal_img_text.py b/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/base_multi_modal_img_text.py new file mode 100644 index 0000000..0b3a9e4 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/base_multi_modal_img_text.py @@ -0,0 +1,419 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from typing import Optional, Tuple, Dict, List, Union +import torch +from torch import Tensor +import argparse +import ftfy +import re +import urllib +import os + +from utils import logger +from utils.ddp_utils import is_master, is_start_rank_node + +from ..dataset_base import BaseImageDataset +from ...transforms import image_pil as T +from ...collate_fns import register_collate_fn +from ...text_tokenizer import build_tokenizer +from .zero_shot import build_zero_shot_dataset + + +class BaseMultiModalImgText(BaseImageDataset): + """ + Base class for Image-Text multi-modal learning + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + + """ + + __separator = ":" + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + + super().__init__( + opts=opts, + is_training=is_training, + is_evaluation=is_evaluation, + *args, + **kwargs + ) + + self.is_master_node = is_master(opts) + self.is_start_rank_node = is_start_rank_node(opts) + + self.text_tokenizer = build_tokenizer(opts=opts, *args, **kwargs) + # CLIP models use a context length of 77 + self.context_length = getattr( + opts, "dataset.multi_modal_img_text.context_length", 77 + ) + + # for sharing padding index across the entire affnet framework, we will + # use a special variable "dataset.padding_index". The default value is set + # to 0. If you need to override the default value, then use + setattr(opts, "dataset.padding_index", None) + self.padding_index = getattr(opts, "dataset.padding_index", None) + + # Because padding index does not exist in vocab, we add 0 for padding index. + # So, we add 1 to total vocab size + vocab_size = self.text_tokenizer.get_vocab_size() + if vocab_size is None or vocab_size == -1: + logger.error( + "Vocab size can't be None or -1 in {}. Got: {}".format( + self.__class__.__name__, vocab_size + ) + ) + self.vocab_size = vocab_size + setattr(opts, "dataset.text_vocab_size", vocab_size) + setattr(opts, "dataset.text_context_length", self.context_length) + + setattr( + opts, "dataset.collate_fn_name_train", "multi_modal_img_text_collate_fn" + ) + setattr(opts, "dataset.collate_fn_name_val", "multi_modal_img_text_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", "multi_modal_img_text_collate_fn") + + self.zero_shot_dataset = self.get_zero_shot_dataset(*args, **kwargs) + self.cached_zero_shot_captions = None + + # Path where we will download data + self.cache_loc = os.path.join(self.root, ".img_text_tar_cache") + os.makedirs(self.cache_loc, exist_ok=True) + + self.dataset = self.get_dataset(*args, **kwargs) + + def get_zero_shot_dataset(self, *args, **kwargs): + zero_shot_eval = ( + False + if self.is_training + else getattr( + self.opts, "dataset.multi_modal_img_text.zero_shot_eval", False + ) + ) + if zero_shot_eval: + zero_shot_dataset = build_zero_shot_dataset(opts=self.opts, *args, **kwargs) + else: + zero_shot_dataset = None + return zero_shot_dataset + + def get_dataset(self, *args, **kwargs): + raise NotImplementedError + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add dataset-specific arguments to the parser.""" + return parser + + def __len__(self): + raise NotImplementedError + + # if self.zeros_shot_dataset is not None: + # return len(self.zeros_shot_dataset) + # return len(self.dataset) + + def _transform_text(self, text_tensor: Tensor) -> Tuple[Tensor, int]: + captions_tensor = torch.zeros(size=(self.context_length,), dtype=torch.long) + + text_len = text_tensor.shape[0] + if text_len > self.context_length: + text_tensor = text_tensor[: self.context_length] + text_tensor[-1] = self.text_tokenizer.get_eot_token() + text_len = self.context_length + captions_tensor[:text_len] = text_tensor[:text_len] + return captions_tensor, text_len + + def _training_transforms(self, size: Union[Tuple, int], *args, **kwargs): + """ + Training data augmentation methods. + Image --> RandomResizedCrop --> RandomHorizontalFlip --> Optional(AutoAugment or RandAugment) + --> Tensor --> Optional(RandomErasing) --> Optional(MixUp) --> Optional(CutMix) + + .. note:: + 1. AutoAugment and RandAugment are mutually exclusive. + 2. Mixup and CutMix are applied on batches are implemented in trainer. + """ + aug_list = [ + T.RandomResizedCrop(opts=self.opts, size=size), + # T.RandomHorizontalFlip(opts=self.opts), + ] + auto_augment = getattr( + self.opts, "image_augmentation.auto_augment.enable", False + ) + rand_augment = getattr( + self.opts, "image_augmentation.rand_augment.enable", False + ) + if auto_augment and rand_augment: + logger.error( + "AutoAugment and RandAugment are mutually exclusive. Use either of them, but not both" + ) + elif auto_augment: + aug_list.append(T.AutoAugment(opts=self.opts)) + elif rand_augment: + aug_list.append(T.RandAugment(opts=self.opts)) + + aug_list.append(T.ToTensor(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.random_erase.enable", False): + aug_list.append(T.RandomErasing(opts=self.opts)) + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: Union[Tuple, int], *args, **kwargs): + """ + Validation augmentation + Image --> Resize --> CenterCrop --> ToTensor + """ + aug_list = [ + T.Resize(opts=self.opts), + T.CenterCrop(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _process_img_caption( + self, input_img, captions_str, img_transform_fn, zero_shot: bool + ) -> Tuple[Tensor, Tensor, int]: + data = {"image": input_img} + img_tensor = img_transform_fn(data)["image"] + + if zero_shot and self.cached_zero_shot_captions is not None: + return ( + img_tensor, + self.cached_zero_shot_captions[0], + self.cached_zero_shot_captions[1], + ) + + max_seq_len = 0 + # process caption + if isinstance(captions_str, str): + captions_tensor, max_seq_len = self._transform_text( + self.text_tokenizer(_caption_preprocessing(captions_str)) + ) + elif isinstance(captions_str, List): + captions_tensor = [] + for captions_str_i in captions_str: + if isinstance(captions_str_i, List): + # captions_str is [ [Num_templates_per_class] * Num_classes] + captions_tensor_i = [] + for ( + captions_str_i_j + ) in captions_str_i: # number of templates per class + seq, seq_len = self._transform_text( + self.text_tokenizer( + _caption_preprocessing(captions_str_i_j) + ) + ) + captions_tensor_i.append(seq) + max_seq_len = max(max_seq_len, seq_len) + captions_tensor_i = torch.stack(captions_tensor_i, dim=0) + captions_tensor.append(captions_tensor_i) + elif isinstance(captions_str_i, str): + # captions_str is [Num_templates_per_image] + seq, seq_len = self._transform_text( + self.text_tokenizer(_caption_preprocessing(captions_str_i)) + ) + captions_tensor.append(seq) + max_seq_len = max(max_seq_len, seq_len) + else: + raise NotImplementedError + # the shape of tensor is [Num_classes, captions_per_class, caption_length] + # or [Captions_per_image, caption_length] + captions_tensor = torch.stack(captions_tensor, dim=0) + else: + captions_tensor = None + logger.error( + "Captions should be either string, List[String] or List[List[str]]" + ) + + if zero_shot and self.cached_zero_shot_captions is None: + self.cached_zero_shot_captions = (captions_tensor, max_seq_len) + + return img_tensor, captions_tensor, max_seq_len + + def get_zero_shot_pair(self, img_index): + img_path, captions_str, class_label = self.zero_shot_dataset(img_index) + input_img = self.read_image_pil(img_path) + return input_img, captions_str, class_label + + def get_dataset_pair(self, img_index): + raise NotImplementedError + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + """ + :param batch_indexes_tup: Tuple of the form (Crop_size_W, Crop_size_H, Image_ID) + :return: dictionary containing input image, label, and sample_id. + """ + crop_size_h, crop_size_w, img_index = batch_indexes_tup + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + else: + # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + if self.zero_shot_dataset is not None: + # read captions and image path from conceptual captions dataset + # read captions and image path from zero-shot dataset + input_img, captions_str, class_label = self.get_zero_shot_pair( + img_index=img_index + ) + else: + input_img, captions_str, class_label = self.get_dataset_pair( + img_index=img_index + ) + + if input_img is None: + captions_tensor = torch.zeros(size=(self.context_length,), dtype=torch.long) + data = { + "samples": { + "image": torch.zeros(size=(3, crop_size_h, crop_size_w)), + "text": captions_tensor, + "padding_mask": (captions_tensor == self.padding_index) + if self.padding_index is not None + else None, + "max_seq_len": self.context_length, + }, + "targets": -1, + } + else: + img_tensor, captions_tensor, max_seq_len = self._process_img_caption( + input_img=input_img, + captions_str=captions_str, + img_transform_fn=transform_fn, + zero_shot=self.zero_shot_dataset is not None, + ) + + data = { + "samples": { + "image": img_tensor, + "text": captions_tensor, + "padding_mask": (captions_tensor == self.padding_index) + if self.padding_index is not None + else None, + "max_seq_len": max_seq_len, + }, + "targets": class_label, + } + + if self.zero_shot_dataset is not None: + data["zero_shot"] = 1 + + return data + + def extra_transform_repr(self): + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "img_transforms={}".format(transforms_str) + + def __repr__(self): + return "{}(\n\troot={}\n\tis_training={}\n\tzero_shot={}\n\tn_samples={}\n\t{}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + self.zero_shot_dataset, + self.__len__(), + self.extra_transform_repr() + ) + + +def _caption_preprocessing(caption: str) -> str: + # captions may contain HTML tokens. Remove them + html_re = re.compile("<.*?>") + caption = urllib.parse.unquote(str(caption)) + caption = caption.replace("+", " ") + caption = re.sub(html_re, "", str(caption)) + # remove the next line + caption = caption.strip("\n") + # remove unwanted spaces + caption = re.sub(" +", " ", caption) + + caption = ftfy.fix_text(caption) + return caption.strip().lower() + + +@register_collate_fn(name="multi_modal_img_text_collate_fn") +def multi_modal_img_text_collate_fn(batch: List, opts) -> Dict: + images = [] + text_tokens = [] + padding_mask = [] + labels = [] + + truncate_seq_len = getattr( + opts, "dataset.multi_modal_img_text.trunc_seq_len", False + ) + + zero_shot = batch[0].pop("zero_shot", 0) + + max_seq_len_in_batch = 1 # at least one token is required in the sequence + for i, batch_i in enumerate(batch): + inputs_i = batch_i.pop("samples") + img_tensor = inputs_i.pop("image", None) + if img_tensor is None: + continue + images.append(img_tensor) + labels.append(batch_i.pop("targets")) + + text_data = inputs_i.pop("text") + pad_mask = inputs_i.pop("padding_mask", None) + max_seq_len_in_batch = max(max_seq_len_in_batch, inputs_i.pop("max_seq_len", 0)) + if zero_shot: + # For zero-shot, all text captions are the same + # so, we only aggregate for one batch element + if i == 0: + text_tokens.append(text_data) + if pad_mask is not None: + padding_mask.append(pad_mask) + else: + text_tokens.append(text_data) + if pad_mask is not None: + padding_mask.append(pad_mask) + + images = torch.stack(images, dim=0) + text_tokens = torch.stack(text_tokens, dim=0) + + # truncate tokens based on the max. seq length + if not truncate_seq_len: + max_seq_len_in_batch = -1 + text_tokens = text_tokens[..., :max_seq_len_in_batch] + + if len(padding_mask) != 0: + padding_mask = torch.stack(padding_mask, dim=0) + padding_mask = padding_mask[..., :max_seq_len_in_batch] + else: + padding_mask = None + + labels = torch.tensor(labels, dtype=torch.long) + + channels_last = getattr(opts, "common.channels_last", False) + if channels_last: + images = images.to(memory_format=torch.channels_last) + + return { + "samples": { + "image": images, + "text": text_tokens, + "padding_mask": padding_mask, + }, + "targets": labels, + } diff --git a/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/img_text_tar_dataset.py b/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/img_text_tar_dataset.py new file mode 100644 index 0000000..6370e74 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/img_text_tar_dataset.py @@ -0,0 +1,411 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +from typing import Optional, Dict +import numpy as np +import torch +import argparse +from PIL import Image, ImageFile +import io +import pickle +import multiprocessing +from multiprocessing.pool import Pool +import tarfile +import glob + +from utils import logger +from utils.download_utils import get_local_path +from utils.ddp_utils import dist_barrier + +from .. import register_dataset +from .base_multi_modal_img_text import BaseMultiModalImgText + +Image.MAX_IMAGE_PIXELS = None +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +def extract_content(tar_file, file_name): + f = tar_file.extractfile(file_name) + return f.read() + + +def decode_image(byte_data): + return Image.open(io.BytesIO(byte_data)).convert("RGB") + + +def decode_text(byte_data): + return byte_data.decode("utf-8") + + +def async_download_file_from_s3( + opts, tar_file_name: str, cache_loc: str, *args, **kwargs +) -> None: + # async download files form s3 + local_path = get_local_path( + opts=opts, + path=tar_file_name, + cache_loc=cache_loc, + quiet_download=True, + force_delete=False, + use_start_rank=False, + sync_ranks=False, + ) + + # now extract the tar file and save the content as each separate file + folder_name = local_path.replace(".tar.gz", "") + with tarfile.open(local_path, "r:gz") as tar_file: + tar_file.extractall(folder_name) + + # delete the tar file, to save space + if os.path.isfile(local_path): + os.remove(local_path) + + +@register_dataset(name="img_text_tar", task="multi_modal_img_text") +class ImgTextTarDataset(BaseMultiModalImgText): + """ + ImgTextTarDataset class for datasets that store Image-Text pairs as tar files, each tar file with multiple pairs. + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + + """ + + __separator = "-" + __overlap_ratio = 10 # we have an overlap ratio of 10 files + __file_extn = ".tar.gz" + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs, + ) -> None: + + super().__init__( + opts=opts, + is_training=is_training, + is_evaluation=is_evaluation, + *args, + **kwargs, + ) + self.zeros_shot_dataset = self.get_zero_shot_dataset() + + if is_training: + dataset_metadata = self.get_dataset() + + total_files = 0 + if -1 in dataset_metadata.keys(): + # At key=-1, we store the information about total files. + total_files = dataset_metadata.pop(-1) + + if total_files == 0: + logger.error( + "Total files can't be 0. Please check if metadata has key -1, which stores the total number of files" + ) + + self.dataset: Dict = dataset_metadata + self.total_pairs = total_files + self.dataset_keys = list(self.dataset.keys()) + + s3_bucket_path = getattr( + self.opts, + "dataset.multi_modal_img_text.img_text_tar.s3_bucket_path", + None, + ) + if s3_bucket_path is None: + if self.is_master_node: + logger.log( + "{} needs the path of AWS bucket where data is stored.".format( + self.__class__.__name__ + ) + ) + + self.s3_bucket_path = s3_bucket_path + + self._download_dataset() + + def get_dataset(self) -> Dict: + if self.is_training: + # read metadata file + # metadata file is a dictionary storing the start image-text ids along with the tar file name. + # Example {'0-18000': 'file_1.tar', 18000-29000', 'file_2.tar'} + metadata_file_loc = getattr( + self.opts, + "dataset.multi_modal_img_text.img_text_tar.metadata_file", + None, + ) + if metadata_file_loc is None: + if self.is_master_node: + logger.error( + "Please specify metadata file using " + "--dataset.multi-modal-img-text.img_text_tar.metadata-file for {}".format( + self.__class__.__name__ + ) + ) + + metadata_file_local_path = get_local_path( + self.opts, path=metadata_file_loc, force_delete=False + ) + with open(metadata_file_local_path, "rb") as fp: + metadata = pickle.load(fp) + return metadata + else: + return {} + + def _download_dataset(self): + if getattr(self.opts, "ddp.enable", False): + if self.is_start_rank_node: + logger.error( + "We need DDP for working with {} dataset".format( + self.__class__.__name__ + ) + ) + + # The total number of GPUs that a task is using is equal to the world size + world_size = getattr(self.opts, "ddp.world_size", -1) + if world_size is None or world_size == -1: + if self.is_start_rank_node: + logger.error("DDP world size should be greater than 1. Got: {}") + + # find the number of GPUs in each node + n_gpus_per_node = torch.cuda.device_count() + + # Total number of GPUs = Total number of nodes * number of GPUs per Node + n_nodes = max(1, world_size // n_gpus_per_node) + + # Find the node id based on current node rank + # node_id = current_node_rank / n_gpus_per_node + curr_node_rank = getattr(self.opts, "ddp.rank", None) + if curr_node_rank is None: + if self.is_start_rank_node: + logger.error("Node rank can't be None.") + node_id = curr_node_rank // n_gpus_per_node + + # Downloading the entire dataset on each node is not feasible. Instead, for each + # node, we will download a subset of the dataset and learn from it. + + # Split the dataset almost equally among all nodes. The length of this split + # is going to be the same as the number of nodes. + node_wise_dataset_split = np.array_split(self.dataset_keys, n_nodes) + + # download files corresponding to ith node + files_node_i = node_wise_dataset_split[node_id] + + # Dataset is organized as a dict where key corresponds to start_index of image-text pair and + # value corresponds to the file name. + + # find the start and end image-text pair indexes for node_i. + # Note that we overlap node_i and node_i+1 by at most 2 files + start_idx_node_i = max( + 0, self.dataset_keys.index(files_node_i[0]) - self.__overlap_ratio + ) + end_idx_node_i = min( + len(self.dataset_keys), + self.dataset_keys.index(files_node_i[-1]) + self.__overlap_ratio, + ) + + # Now, download the files concurrently using each rank on node i + # Now, download the files concurrently using each rank on node i + indexes_to_download_node_i = self.dataset_keys[start_idx_node_i:end_idx_node_i] + indexes_to_download_node_i_rank_j = np.array_split( + indexes_to_download_node_i, n_gpus_per_node + ) + total_files_to_download = len(indexes_to_download_node_i) + if self.is_start_rank_node: + logger.log(f"Starting to downloading {total_files_to_download} files") + + current_device = torch.cuda.current_device() + if getattr( + self.opts, + "dataset.multi_modal_img_text.img_text_tar.parallel_download", + False, + ): + # download concurrently using many workers for each rank + n_cpus = multiprocessing.cpu_count() + n_process_per_gpu = max( + 1, n_cpus // torch.cuda.device_count() + ) # max(1, min(4, n_cpus // torch.cuda.device_count())) + with Pool(processes=n_process_per_gpu) as pool: + pool.starmap( + async_download_file_from_s3, + [ + ( + self.opts, + os.path.join( + self.s3_bucket_path, self.dataset[img_text_idx] + ), + self.cache_loc, + ) + for img_text_idx in indexes_to_download_node_i_rank_j[ + current_device + ] + ], + ) + else: + # download sequentially (1 worker per rank) + for count, img_text_idx in enumerate( + indexes_to_download_node_i_rank_j[current_device] + ): + # Recall that dataset is organized as a dict where key corresponds to start_index of image-text pair + # value corresponds to the tar file name. + + async_download_file_from_s3( + opts=self.opts, + tar_file_name=os.path.join( + self.s3_bucket_path, self.dataset[img_text_idx] + ), + cache_loc=self.cache_loc, + ) + + if count % 100 == 0 and self.is_start_rank_node: + n_files_downloaded = len(glob.glob(f"{self.cache_loc}/*")) + print( + f"Progress: {n_files_downloaded}/{total_files_to_download}", + end="\r", + ) + + # synchronize between all DDP jobs + if getattr(self.opts, "ddp.use_distributed", False): + dist_barrier() + + if self.is_start_rank_node: + n_files_downloaded = len(glob.glob(f"{self.cache_loc}/*")) + logger.log( + f"Download complete ({n_files_downloaded}/{total_files_to_download}). " + f"Files are stored at: {self.cache_loc}" + ) + + def __len__(self): + if self.zeros_shot_dataset is not None: + return len(self.zeros_shot_dataset) + return self.total_pairs + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add dataset-specific arguments to the parser.""" + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--dataset.multi-modal-img-text.img-text-tar.metadata-file", + type=str, + default=None, + help="Location of the metadata file", + ) + + group.add_argument( + "--dataset.multi-modal-img-text.img-text-tar.s3-bucket-path", + type=str, + default=None, + help="Path of the s3 bucket where data is stored", + ) + + group.add_argument( + "--dataset.multi-modal-img-text.img-text-tar.parallel-download", + action="store_true", + help="Download the data in parallel on each rank of the DDP process", + ) + + return parser + + def get_dataset_pair(self, img_index): + class_label = -1 + try: + + if img_index in self.dataset_keys: + # file index is the same as the start index + file_index = self.dataset_keys.index(img_index) + # data index is 0 because file index is one of the start indices + data_index = 0 + img_text_pair_id = self.dataset_keys[file_index] + else: + # find the index at which the element will be inserted. + # Example: If we have an array of start indices as [0, 15, 35, 90] and + # we want to find the position of image index 92, then the insertion index will be 4. + insertion_idx = np.searchsorted(self.dataset_keys, img_index) + + # the image id corresponding to 92 is stored in file whose start index is 90. + # So, the file index is one less than insertion index + file_index = insertion_idx - 1 + + img_text_pair_id = self.dataset_keys[file_index] + + # data index is delta between current value (92) and value at file index (90) + data_index = img_index - img_text_pair_id + + # get the key corresponding to file index and retrieve the file name + # concatenate the file name with cache location path + tar_file_name_from_metadata = self.dataset[img_text_pair_id] + + tar_file_name = os.path.join(self.cache_loc, tar_file_name_from_metadata) + # Tar file name is encoded as: / + # each file name in tar file is encoded as: and + """ + Example. + + img_text_tar_dataset/00000000_0_1000.tar.gz + |--- 00000000_0_image + |--- 00000000_0_text + |--- 00000000_1_image + |--- 00000000_1_text + |--- ... + + img_text_tar_dataset/00000000_1000_2000.tar.gz + |--- 00000000_1000_image + |--- 00000000_1000_text + |--- 00000000_1001_image + |--- 00000000_1001_text + |--- ... + """ + + # remove the tar extension because we have extracted the data when downloaded + tar_file_name = tar_file_name.replace(self.__file_extn, "") + + if not os.path.isdir(tar_file_name): + async_download_file_from_s3( + opts=self.opts, + tar_file_name=os.path.join( + self.s3_bucket_path, tar_file_name_from_metadata + ), + cache_loc=self.cache_loc, + ) + + # Based on this, decode the folder information + folder_name = tar_file_name.split(os.sep)[-1].split("_")[0] + + # adjust the data index with start_id offset + start_id = tar_file_name.split(os.sep)[-1].split("_")[1] + data_index = data_index + int(start_id) + + img_text_fname = f"{tar_file_name}/{folder_name}_{data_index}" + with open(f"{img_text_fname}_image", "rb") as img_byte_data: + input_img = decode_image(img_byte_data.read()) + + with open(f"{img_text_fname}_text", "rb") as text_byte_data: + captions_str = decode_text(text_byte_data.read()) + except Exception as e: + logger.log("error loading {}. Error message: {}".format(img_index, str(e))) + + input_img = None + captions_str = None + return input_img, captions_str, class_label + + def __repr__(self): + return "{}(\n\troot={}\n\t is_training={}\n\tzero_shot={}\n\tn_samples={}\n\t{}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + self.zeros_shot_dataset, + self.__len__(), + self.extra_transform_repr() + ) diff --git a/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/zero_shot/__init__.py b/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/zero_shot/__init__.py new file mode 100644 index 0000000..2662d5f --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/zero_shot/__init__.py @@ -0,0 +1,95 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import importlib +import argparse +import glob + +from utils.ddp_utils import is_master +from utils import logger + +from .base_zero_shot import BaseZeroShotDataset + + +ZERO_SHOT_DATASET_REGISTRY = {} + +SEPARATOR = ":" + + +def register_zero_shot_dataset(name): + """Helper function to register zero-shot datasets""" + + def register_zero_shot_dataset_class(cls): + if name in ZERO_SHOT_DATASET_REGISTRY: + raise ValueError( + "Cannot register duplicate zero-shot dataset class ({})".format(name) + ) + + if not issubclass(cls, BaseZeroShotDataset): + raise ValueError( + "Zero shot dataset ({}: {}) must extend BaseZeroShotDataset".format( + name, cls.__name__ + ) + ) + + ZERO_SHOT_DATASET_REGISTRY[name] = cls + return cls + + return register_zero_shot_dataset_class + + +def supported_zero_shot_dataset_str(dataset_name) -> None: + """Helper function to print error message in case zero shot dataset is not available""" + + supp_list = list(ZERO_SHOT_DATASET_REGISTRY.keys()) + supp_str = "Zero shot dataset ({}) is not yet supported. \n Supported datasets are:".format( + dataset_name + ) + for i, d_name in enumerate(supp_list): + supp_str += "\n\t\t{}: {}".format(i, d_name) + logger.error(supp_str + "\n") + + +def arguments_zero_shot_dataset( + parser: argparse.ArgumentParser, +) -> argparse.ArgumentParser: + """Helper function to get zero-shot dataset arguments""" + + parser = BaseZeroShotDataset.add_arguments(parser=parser) + + # add dataset specific arguments + for k, v in ZERO_SHOT_DATASET_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + return parser + + +def build_zero_shot_dataset(opts, *args, **kwargs): + """Helper function to build the zero shot datasets""" + zero_shot_dataset_name = getattr( + opts, "dataset.multi_modal_img_text.zero_shot.name", None + ) + + if zero_shot_dataset_name in list(ZERO_SHOT_DATASET_REGISTRY.keys()): + return ZERO_SHOT_DATASET_REGISTRY[zero_shot_dataset_name](opts, *args, **kwargs) + else: + supported_zero_shot_dataset_str(zero_shot_dataset_name) + + +# automatically import zero-shot datasets +dataset_dir = os.path.dirname(__file__) + +for file in os.listdir(dataset_dir): + path = os.path.join(dataset_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + zs_dataset_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module( + "data.datasets.multi_modal_img_text.zero_shot." + zs_dataset_name + ) diff --git a/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/zero_shot/base_zero_shot.py b/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/zero_shot/base_zero_shot.py new file mode 100644 index 0000000..34f1be9 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/zero_shot/base_zero_shot.py @@ -0,0 +1,47 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch.utils import data +import argparse +import time + +from utils import logger +from utils.ddp_utils import is_start_rank_node, dist_barrier + + +class BaseZeroShotDataset(object): + """ + Base Dataset class for Zero shot tasks + """ + + def __init__(self, opts, *args, **kwargs): + if getattr(opts, "dataset.multi_modal_img_text.zero_shot.trove.enable", False): + try: + from internal.utils.server_utils import load_from_data_server + + opts = load_from_data_server( + opts=opts, is_training=False, arg_prefix="dataset.multi_modal_img_text.zero_shot" + ) + except Exception as e: + logger.error("Unable to load from the server. Error: {}".format(str(e))) + + root = getattr(opts, "dataset.multi_modal_img_text.zero_shot.root_val", None) + self.root = root + self.opts = opts + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + def __len__(self): + raise NotImplementedError + + @staticmethod + def class_names(): + pass + + def __repr__(self): + return "{}(root={})".format(self.__class__.__name__, self.root) diff --git a/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/zero_shot/imagenet.py b/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/zero_shot/imagenet.py new file mode 100644 index 0000000..e83c46e --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/multi_modal_img_text/zero_shot/imagenet.py @@ -0,0 +1,1143 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torchvision.datasets import ImageFolder +from typing import Tuple, List +import argparse + +from . import BaseZeroShotDataset, register_zero_shot_dataset + + +@register_zero_shot_dataset(name="imagenet") +class ImagenetDatasetZeroShot(BaseZeroShotDataset, ImageFolder): + """ + ImageNet Dataset for zero-shot evaluation + """ + + def __init__(self, opts, *args, **kwargs) -> None: + BaseZeroShotDataset.__init__(self, opts=opts, *args, **kwargs) + root = self.root + ImageFolder.__init__( + self, root=root, transform=None, target_transform=None, is_valid_file=None + ) + + n_classes = len(list(self.class_to_idx.keys())) + + class_names: List = self.class_names() + + templates = [] + for class_id in range(n_classes): + templates_class_i = clip_text_template(class_names[class_id].lower()) + templates.append(templates_class_i) + + self.text_templates = templates + + def __call__(self, img_index: int) -> Tuple[str, List[List[str]], int]: + """ + :param img_index: Index of the image + :return: Tuple[str, List[str]]: Tuple containing image path and list of captions + """ + + img_path, target = self.samples[img_index] + return img_path, self.text_templates, target + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + def __len__(self) -> int: + return len(self.samples) + + @staticmethod + def class_names() -> List[str]: + """ImageNet class names""" + return [ + "tench", + "goldfish", + "great white shark", + "tiger shark", + "hammerhead shark", + "electric ray", + "stingray", + "rooster", + "hen", + "ostrich", + "brambling", + "goldfinch", + "house finch", + "junco", + "indigo bunting", + "American robin", + "bulbul", + "jay", + "magpie", + "chickadee", + "American dipper", + "kite (bird of prey)", + "bald eagle", + "vulture", + "great grey owl", + "fire salamander", + "smooth newt", + "newt", + "spotted salamander", + "axolotl", + "American bullfrog", + "tree frog", + "tailed frog", + "loggerhead sea turtle", + "leatherback sea turtle", + "mud turtle", + "terrapin", + "box turtle", + "banded gecko", + "green iguana", + "Carolina anole", + "desert grassland whiptail lizard", + "agama", + "frilled-necked lizard", + "alligator lizard", + "Gila monster", + "European green lizard", + "chameleon", + "Komodo dragon", + "Nile crocodile", + "American alligator", + "triceratops", + "worm snake", + "ring-necked snake", + "eastern hog-nosed snake", + "smooth green snake", + "kingsnake", + "garter snake", + "water snake", + "vine snake", + "night snake", + "boa constrictor", + "African rock python", + "Indian cobra", + "green mamba", + "sea snake", + "Saharan horned viper", + "eastern diamondback rattlesnake", + "sidewinder rattlesnake", + "trilobite", + "harvestman", + "scorpion", + "yellow garden spider", + "barn spider", + "European garden spider", + "southern black widow", + "tarantula", + "wolf spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse", + "prairie grouse", + "peafowl", + "quail", + "partridge", + "african grey parrot", + "macaw", + "sulphur-crested cockatoo", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "duck", + "red-breasted merganser", + "goose", + "black swan", + "tusker", + "echidna", + "platypus", + "wallaby", + "koala", + "wombat", + "jellyfish", + "sea anemone", + "brain coral", + "flatworm", + "nematode", + "conch", + "snail", + "slug", + "sea slug", + "chiton", + "chambered nautilus", + "Dungeness crab", + "rock crab", + "fiddler crab", + "red king crab", + "American lobster", + "spiny lobster", + "crayfish", + "hermit crab", + "isopod", + "white stork", + "black stork", + "spoonbill", + "flamingo", + "little blue heron", + "great egret", + "bittern bird", + "crane bird", + "limpkin", + "common gallinule", + "American coot", + "bustard", + "ruddy turnstone", + "dunlin", + "common redshank", + "dowitcher", + "oystercatcher", + "pelican", + "king penguin", + "albatross", + "grey whale", + "killer whale", + "dugong", + "sea lion", + "Chihuahua", + "Japanese Chin", + "Maltese", + "Pekingese", + "Shih Tzu", + "King Charles Spaniel", + "Papillon", + "toy terrier", + "Rhodesian Ridgeback", + "Afghan Hound", + "Basset Hound", + "Beagle", + "Bloodhound", + "Bluetick Coonhound", + "Black and Tan Coonhound", + "Treeing Walker Coonhound", + "English foxhound", + "Redbone Coonhound", + "borzoi", + "Irish Wolfhound", + "Italian Greyhound", + "Whippet", + "Ibizan Hound", + "Norwegian Elkhound", + "Otterhound", + "Saluki", + "Scottish Deerhound", + "Weimaraner", + "Staffordshire Bull Terrier", + "American Staffordshire Terrier", + "Bedlington Terrier", + "Border Terrier", + "Kerry Blue Terrier", + "Irish Terrier", + "Norfolk Terrier", + "Norwich Terrier", + "Yorkshire Terrier", + "Wire Fox Terrier", + "Lakeland Terrier", + "Sealyham Terrier", + "Airedale Terrier", + "Cairn Terrier", + "Australian Terrier", + "Dandie Dinmont Terrier", + "Boston Terrier", + "Miniature Schnauzer", + "Giant Schnauzer", + "Standard Schnauzer", + "Scottish Terrier", + "Tibetan Terrier", + "Australian Silky Terrier", + "Soft-coated Wheaten Terrier", + "West Highland White Terrier", + "Lhasa Apso", + "Flat-Coated Retriever", + "Curly-coated Retriever", + "Golden Retriever", + "Labrador Retriever", + "Chesapeake Bay Retriever", + "German Shorthaired Pointer", + "Vizsla", + "English Setter", + "Irish Setter", + "Gordon Setter", + "Brittany dog", + "Clumber Spaniel", + "English Springer Spaniel", + "Welsh Springer Spaniel", + "Cocker Spaniel", + "Sussex Spaniel", + "Irish Water Spaniel", + "Kuvasz", + "Schipperke", + "Groenendael dog", + "Malinois", + "Briard", + "Australian Kelpie", + "Komondor", + "Old English Sheepdog", + "Shetland Sheepdog", + "collie", + "Border Collie", + "Bouvier des Flandres dog", + "Rottweiler", + "German Shepherd Dog", + "Dobermann", + "Miniature Pinscher", + "Greater Swiss Mountain Dog", + "Bernese Mountain Dog", + "Appenzeller Sennenhund", + "Entlebucher Sennenhund", + "Boxer", + "Bullmastiff", + "Tibetan Mastiff", + "French Bulldog", + "Great Dane", + "St. Bernard", + "husky", + "Alaskan Malamute", + "Siberian Husky", + "Dalmatian", + "Affenpinscher", + "Basenji", + "pug", + "Leonberger", + "Newfoundland dog", + "Great Pyrenees dog", + "Samoyed", + "Pomeranian", + "Chow Chow", + "Keeshond", + "brussels griffon", + "Pembroke Welsh Corgi", + "Cardigan Welsh Corgi", + "Toy Poodle", + "Miniature Poodle", + "Standard Poodle", + "Mexican hairless dog (xoloitzcuintli)", + "grey wolf", + "Alaskan tundra wolf", + "red wolf or maned wolf", + "coyote", + "dingo", + "dhole", + "African wild dog", + "hyena", + "red fox", + "kit fox", + "Arctic fox", + "grey fox", + "tabby cat", + "tiger cat", + "Persian cat", + "Siamese cat", + "Egyptian Mau", + "cougar", + "lynx", + "leopard", + "snow leopard", + "jaguar", + "lion", + "tiger", + "cheetah", + "brown bear", + "American black bear", + "polar bear", + "sloth bear", + "mongoose", + "meerkat", + "tiger beetle", + "ladybug", + "ground beetle", + "longhorn beetle", + "leaf beetle", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant", + "grasshopper", + "cricket insect", + "stick insect", + "cockroach", + "praying mantis", + "cicada", + "leafhopper", + "lacewing", + "dragonfly", + "damselfly", + "red admiral butterfly", + "ringlet butterfly", + "monarch butterfly", + "small white butterfly", + "sulphur butterfly", + "gossamer-winged butterfly", + "starfish", + "sea urchin", + "sea cucumber", + "cottontail rabbit", + "hare", + "Angora rabbit", + "hamster", + "porcupine", + "fox squirrel", + "marmot", + "beaver", + "guinea pig", + "common sorrel horse", + "zebra", + "pig", + "wild boar", + "warthog", + "hippopotamus", + "ox", + "water buffalo", + "bison", + "ram (adult male sheep)", + "bighorn sheep", + "Alpine ibex", + "hartebeest", + "impala (antelope)", + "gazelle", + "arabian camel", + "llama", + "weasel", + "mink", + "European polecat", + "black-footed ferret", + "otter", + "skunk", + "badger", + "armadillo", + "three-toed sloth", + "orangutan", + "gorilla", + "chimpanzee", + "gibbon", + "siamang", + "guenon", + "patas monkey", + "baboon", + "macaque", + "langur", + "black-and-white colobus", + "proboscis monkey", + "marmoset", + "white-headed capuchin", + "howler monkey", + "titi monkey", + "Geoffroy's spider monkey", + "common squirrel monkey", + "ring-tailed lemur", + "indri", + "Asian elephant", + "African bush elephant", + "red panda", + "giant panda", + "snoek fish", + "eel", + "silver salmon", + "rock beauty fish", + "clownfish", + "sturgeon", + "gar fish", + "lionfish", + "pufferfish", + "abacus", + "abaya", + "academic gown", + "accordion", + "acoustic guitar", + "aircraft carrier", + "airliner", + "airship", + "altar", + "ambulance", + "amphibious vehicle", + "analog clock", + "apiary", + "apron", + "trash can", + "assault rifle", + "backpack", + "bakery", + "balance beam", + "balloon", + "ballpoint pen", + "Band-Aid", + "banjo", + "baluster / handrail", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel", + "wheelbarrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "swimming cap", + "bath towel", + "bathtub", + "station wagon", + "lighthouse", + "beaker", + "military hat (bearskin or shako)", + "beer bottle", + "beer glass", + "bell tower", + "baby bib", + "tandem bicycle", + "bikini", + "ring binder", + "binoculars", + "birdhouse", + "boathouse", + "bobsleigh", + "bolo tie", + "poke bonnet", + "bookcase", + "bookstore", + "bottle cap", + "hunting bow", + "bow tie", + "brass memorial plaque", + "bra", + "breakwater", + "breastplate", + "broom", + "bucket", + "buckle", + "bulletproof vest", + "high-speed train", + "butcher shop", + "taxicab", + "cauldron", + "candle", + "cannon", + "canoe", + "can opener", + "cardigan", + "car mirror", + "carousel", + "tool kit", + "cardboard box / carton", + "car wheel", + "automated teller machine", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello", + "mobile phone", + "chain", + "chain-link fence", + "chain mail", + "chainsaw", + "storage chest", + "chiffonier", + "bell or wind chime", + "china cabinet", + "Christmas stocking", + "church", + "movie theater", + "cleaver", + "cliff dwelling", + "cloak", + "clogs", + "cocktail shaker", + "coffee mug", + "coffeemaker", + "spiral or coil", + "combination lock", + "computer keyboard", + "candy store", + "container ship", + "convertible", + "corkscrew", + "cornet", + "cowboy boot", + "cowboy hat", + "cradle", + "construction crane", + "crash helmet", + "crate", + "infant bed", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam", + "desk", + "desktop computer", + "rotary dial telephone", + "diaper", + "digital clock", + "digital watch", + "dining table", + "dishcloth", + "dishwasher", + "disc brake", + "dock", + "dog sled", + "dome", + "doormat", + "drilling rig", + "drum", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso machine", + "face powder", + "feather boa", + "filing cabinet", + "fireboat", + "fire truck", + "fire screen", + "flagpole", + "flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster bed", + "freight car", + "French horn", + "frying pan", + "fur coat", + "garbage truck", + "gas mask or respirator", + "gas pump", + "goblet", + "go-kart", + "golf ball", + "golf cart", + "gondola", + "gong", + "gown", + "grand piano", + "greenhouse", + "radiator grille", + "grocery store", + "guillotine", + "hair clip", + "hair spray", + "half-track", + "hammer", + "hamper", + "hair dryer", + "hand-held computer", + "handkerchief", + "hard disk drive", + "harmonica", + "harp", + "combine harvester", + "hatchet", + "holster", + "home theater", + "honeycomb", + "hook", + "hoop skirt", + "gymnastic horizontal bar", + "horse-drawn vehicle", + "hourglass", + "iPod", + "clothes iron", + "carved pumpkin", + "jeans", + "jeep", + "T-shirt", + "jigsaw puzzle", + "rickshaw", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat", + "ladle", + "lampshade", + "laptop computer", + "lawn mower", + "lens cap", + "letter opener", + "library", + "lifeboat", + "lighter", + "limousine", + "ocean liner", + "lipstick", + "slip-on shoe", + "lotion", + "music speaker", + "loupe magnifying glass", + "sawmill", + "magnetic compass", + "messenger bag", + "mailbox", + "tights", + "one-piece bathing suit", + "manhole cover", + "maraca", + "marimba", + "mask", + "matchstick", + "maypole", + "maze", + "measuring cup", + "medicine cabinet", + "megalith", + "microphone", + "microwave oven", + "military uniform", + "milk can", + "minibus", + "miniskirt", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home", + "ford model t", + "modem", + "monastery", + "monitor", + "moped", + "mortar and pestle", + "graduation cap", + "mosque", + "mosquito net", + "vespa", + "mountain bike", + "tent", + "computer mouse", + "mousetrap", + "moving van", + "muzzle", + "metal nail", + "neck brace", + "necklace", + "baby pacifier", + "notebook computer", + "obelisk", + "oboe", + "ocarina", + "odometer", + "oil filter", + "pipe organ", + "oscilloscope", + "overskirt", + "bullock cart", + "oxygen mask", + "product packet / packaging", + "paddle", + "paddle wheel", + "padlock", + "paintbrush", + "pajamas", + "palace", + "pan flute", + "paper towel", + "parachute", + "parallel bars", + "park bench", + "parking meter", + "railroad car", + "patio", + "payphone", + "pedestal", + "pencil case", + "pencil sharpener", + "perfume", + "Petri dish", + "photocopier", + "plectrum", + "Pickelhaube", + "picket fence", + "pickup truck", + "pier", + "piggy bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate ship", + "drink pitcher", + "block plane", + "planetarium", + "plastic bag", + "plate rack", + "farm plow", + "plunger", + "Polaroid camera", + "pole", + "police van", + "poncho", + "pool table", + "soda bottle", + "plant pot", + "potter's wheel", + "power drill", + "prayer rug", + "printer", + "prison", + "missile", + "projector", + "hockey puck", + "punching bag", + "purse", + "quill", + "quilt", + "race car", + "racket", + "radiator", + "radio", + "radio telescope", + "rain barrel", + "recreational vehicle", + "fishing casting reel", + "reflex camera", + "refrigerator", + "remote control", + "restaurant", + "revolver", + "rifle", + "rocking chair", + "rotisserie", + "eraser", + "rugby ball", + "ruler measuring stick", + "sneaker", + "safe", + "safety pin", + "salt shaker", + "sandal", + "sarong", + "saxophone", + "scabbard", + "weighing scale", + "school bus", + "schooner", + "scoreboard", + "CRT monitor", + "screw", + "screwdriver", + "seat belt", + "sewing machine", + "shield", + "shoe store", + "shoji screen / room divider", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "balaclava ski mask", + "sleeping bag", + "slide rule", + "sliding door", + "slot machine", + "snorkel", + "snowmobile", + "snowplow", + "soap dispenser", + "soccer ball", + "sock", + "solar thermal collector", + "sombrero", + "soup bowl", + "keyboard space bar", + "space heater", + "space shuttle", + "spatula", + "motorboat", + "spider web", + "spindle", + "sports car", + "spotlight", + "stage", + "steam locomotive", + "through arch bridge", + "steel drum", + "stethoscope", + "scarf", + "stone wall", + "stopwatch", + "stove", + "strainer", + "tram", + "stretcher", + "couch", + "stupa", + "submarine", + "suit", + "sundial", + "sunglasses", + "sunglasses", + "sunscreen", + "suspension bridge", + "mop", + "sweatshirt", + "swim trunks / shorts", + "swing", + "electrical switch", + "syringe", + "table lamp", + "tank", + "tape player", + "teapot", + "teddy bear", + "television", + "tennis ball", + "thatched roof", + "front curtain", + "thimble", + "threshing machine", + "throne", + "tile roof", + "toaster", + "tobacco shop", + "toilet seat", + "torch", + "totem pole", + "tow truck", + "toy store", + "tractor", + "semi-trailer truck", + "tray", + "trench coat", + "tricycle", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus", + "trombone", + "hot tub", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle", + "upright piano", + "vacuum cleaner", + "vase", + "vaulted or arched ceiling", + "velvet fabric", + "vending machine", + "vestment", + "viaduct", + "violin", + "volleyball", + "waffle iron", + "wall clock", + "wallet", + "wardrobe", + "military aircraft", + "sink", + "washing machine", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "hair wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "airplane wing", + "wok", + "wooden spoon", + "wool", + "split-rail fence", + "shipwreck", + "sailboat", + "yurt", + "website", + "comic book", + "crossword", + "traffic or street sign", + "traffic light", + "dust jacket", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot", + "trifle", + "ice cream", + "popsicle", + "baguette", + "bagel", + "pretzel", + "cheeseburger", + "hot dog", + "mashed potatoes", + "cabbage", + "broccoli", + "cauliflower", + "zucchini", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber", + "artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith apple", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple", + "banana", + "jackfruit", + "cherimoya (custard apple)", + "pomegranate", + "hay", + "carbonara", + "chocolate syrup", + "dough", + "meatloaf", + "pizza", + "pot pie", + "burrito", + "red wine", + "espresso", + "tea cup", + "eggnog", + "mountain", + "bubble", + "cliff", + "coral reef", + "geyser", + "lakeshore", + "promontory", + "sandbar", + "beach", + "valley", + "volcano", + "baseball player", + "bridegroom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper", + "corn", + "acorn", + "rose hip", + "horse chestnut seed", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn mushroom", + "earth star fungus", + "hen of the woods mushroom", + "bolete", + "corn cob", + "toilet paper", + ] + + +def clip_text_template(class_name): + return [ + f"a bad photo of a {class_name}.", + f"a photo of many {class_name}.", + f"a sculpture of a {class_name}.", + f"a photo of the hard to see {class_name}.", + f"a low resolution photo of the {class_name}.", + f"a rendering of a {class_name}.", + f"graffiti of a {class_name}.", + f"a bad photo of the {class_name}.", + f"a cropped photo of the {class_name}.", + f"a tattoo of a {class_name}.", + f"the embroidered {class_name}.", + f"a photo of a hard to see {class_name}.", + f"a bright photo of a {class_name}.", + f"a photo of a clean {class_name}.", + f"a photo of a dirty {class_name}.", + f"a dark photo of the {class_name}.", + f"a drawing of a {class_name}.", + f"a photo of my {class_name}.", + f"the plastic {class_name}.", + f"a photo of the cool {class_name}.", + f"a close-up photo of a {class_name}.", + f"a black and white photo of the {class_name}.", + f"a painting of the {class_name}.", + f"a painting of a {class_name}.", + f"a pixelated photo of the {class_name}.", + f"a sculpture of the {class_name}.", + f"a bright photo of the {class_name}.", + f"a cropped photo of a {class_name}.", + f"a plastic {class_name}.", + f"a photo of the dirty {class_name}.", + f"a jpeg corrupted photo of a {class_name}.", + f"a blurry photo of the {class_name}.", + f"a photo of the {class_name}.", + f"a good photo of the {class_name}.", + f"a rendering of the {class_name}.", + f"a {class_name} in a video game.", + f"a photo of one {class_name}.", + f"a doodle of a {class_name}.", + f"a close-up photo of the {class_name}.", + f"a photo of a {class_name}.", + f"the origami {class_name}.", + f"the {class_name} in a video game.", + f"a sketch of a {class_name}.", + f"a doodle of the {class_name}.", + f"a origami {class_name}.", + f"a low resolution photo of a {class_name}.", + f"the toy {class_name}.", + f"a rendition of the {class_name}.", + f"a photo of the clean {class_name}.", + f"a photo of a large {class_name}.", + f"a rendition of a {class_name}.", + f"a photo of a nice {class_name}.", + f"a photo of a weird {class_name}.", + f"a blurry photo of a {class_name}.", + f"a cartoon {class_name}.", + f"art of a {class_name}.", + f"a sketch of the {class_name}.", + f"a embroidered {class_name}.", + f"a pixelated photo of a {class_name}.", + f"itap of the {class_name}.", + f"a jpeg corrupted photo of the {class_name}.", + f"a good photo of a {class_name}.", + f"a plushie {class_name}.", + f"a photo of the nice {class_name}.", + f"a photo of the small {class_name}.", + f"a photo of the weird {class_name}.", + f"the cartoon {class_name}.", + f"art of the {class_name}.", + f"a drawing of the {class_name}.", + f"a photo of the large {class_name}.", + f"a black and white photo of a {class_name}.", + f"the plushie {class_name}.", + f"a dark photo of a {class_name}.", + f"itap of a {class_name}.", + f"graffiti of the {class_name}.", + f"a toy {class_name}.", + f"itap of my {class_name}.", + f"a photo of a cool {class_name}.", + f"a photo of a small {class_name}.", + f"a tattoo of the {class_name}.", + ] diff --git a/Adaptive Frequency Filters/data/datasets/segmentation/__init__.py b/Adaptive Frequency Filters/data/datasets/segmentation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/data/datasets/segmentation/ade20k.py b/Adaptive Frequency Filters/data/datasets/segmentation/ade20k.py new file mode 100644 index 0000000..afe4632 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/segmentation/ade20k.py @@ -0,0 +1,522 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +from typing import Optional, List, Dict, Tuple +import numpy as np + +from utils import logger + +from .. import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_pil as T + + +@register_dataset(name="ade20k", task="segmentation") +class ADE20KDataset(BaseImageDataset): + """ + Dataset class for the ADE20K dataset + + The structure of the dataset should be something like this: :: + + ADEChallengeData2016/annotations/training/*.png + ADEChallengeData2016/annotations/validation/*.png + + ADEChallengeData2016/images/training/*.jpg + ADEChallengeData2016/images/validation/*.jpg + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + """ + + :param opts: arguments + :param is_training: Training or validation mode + :param is_evaluation: Evaluation mode + """ + super().__init__( + opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + root = self.root + + image_dir = os.path.join( + root, "images", "training" if is_training else "validation" + ) + annotation_dir = os.path.join( + root, "annotations", "training" if is_training else "validation" + ) + + images = [] + masks = [] + for file_name in os.listdir(image_dir): + if file_name.endswith(".jpg"): + img_f_name = "{}/{}".format(image_dir, file_name) + mask_f_name = "{}/{}".format( + annotation_dir, file_name.replace("jpg", "png") + ) + + if os.path.isfile(img_f_name) and os.path.isfile(mask_f_name): + images.append(img_f_name) + masks.append(mask_f_name) + + self.images = images + self.masks = masks + self.ignore_label = 255 + self.bgrnd_idx = 0 + setattr( + opts, "model.segmentation.n_classes", len(self.class_names()) - 1 + ) # ignore background + + # set the collate functions for the dataset + # For evaluation, we use PyTorch's default collate function. So, we set to collate_fn_name_eval to None + setattr(opts, "dataset.collate_fn_name_train", "default_collate_fn") + setattr(opts, "dataset.collate_fn_name_val", "default_collate_fn") + setattr(opts, "dataset.collate_fn_name_eval", None) + + def _training_transforms(self, size: tuple): + first_aug = T.RandomShortSizeResize(opts=self.opts) + aug_list = [ + T.RandomHorizontalFlip(opts=self.opts), + T.RandomCrop(opts=self.opts, size=size, ignore_idx=self.ignore_label), + ] + + if getattr(self.opts, "image_augmentation.random_gaussian_noise.enable", False): + aug_list.append(T.RandomGaussianBlur(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.photo_metric_distort.enable", False): + aug_list.append(T.PhotometricDistort(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.random_rotate.enable", False): + aug_list.append(T.RandomRotate(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.random_order.enable", False): + new_aug_list = [ + first_aug, + T.RandomOrder(opts=self.opts, img_transforms=aug_list), + T.ToTensor(opts=self.opts), + ] + return T.Compose(opts=self.opts, img_transforms=new_aug_list) + else: + aug_list.insert(0, first_aug) + aug_list.append(T.ToTensor(opts=self.opts)) + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: tuple, *args, **kwargs): + aug_list = [T.Resize(opts=self.opts), T.ToTensor(opts=self.opts)] + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _evaluation_transforms(self, size: tuple, *args, **kwargs): + aug_list = [] + if getattr(self.opts, "evaluation.segmentation.resize_input_images", False): + # we want to resize while maintaining aspect ratio. So, we pass img_size argument to resize function + aug_list.append(T.Resize(opts=self.opts, img_size=min(size))) + + aug_list.append(T.ToTensor(opts=self.opts)) + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def __getitem__(self, batch_indexes_tup: Tuple[int, int, int]) -> Dict: + crop_size_h, crop_size_w, img_index = batch_indexes_tup + crop_size = (crop_size_h, crop_size_w) + + if self.is_training: + _transform = self._training_transforms(size=crop_size) + elif self.is_evaluation: + _transform = self._evaluation_transforms(size=crop_size) + else: + _transform = self._validation_transforms(size=crop_size) + + mask = self.read_mask_pil(self.masks[img_index]) + img = self.read_image_pil(self.images[img_index]) + + if (img.size[0] != mask.size[0]) or (img.size[1] != mask.size[1]): + logger.error( + "Input image and mask sizes are different. Input size: {} and Mask size: {}".format( + img.size, mask.size + ) + ) + + data = {"image": img} + if not self.is_evaluation: + data["mask"] = mask + + data = _transform(data) + + if self.is_evaluation: + # for evaluation purposes, resize only the input and not mask + data["mask"] = self.convert_mask_to_tensor(mask) + + output_data = { + "samples": data["image"], + "targets": data["mask"] - 1, # ignore background during training + } + + if self.is_evaluation: + im_width, im_height = img.size + img_name = self.images[img_index].split(os.sep)[-1].replace("jpg", "png") + mask = output_data.pop("targets") + output_data["targets"] = { + "mask": mask, + "file_name": img_name, + "im_width": im_width, + "im_height": im_height, + } + + return output_data + + @staticmethod + def adjust_mask_value(): + return 1 + + def __len__(self) -> int: + return len(self.images) + + @staticmethod + def color_palette() -> List: + color_codes = [ + [0, 0, 0], # background + [120, 120, 120], + [180, 120, 120], + [6, 230, 230], + [80, 50, 50], + [4, 200, 3], + [120, 120, 80], + [140, 140, 140], + [204, 5, 255], + [230, 230, 230], + [4, 250, 7], + [224, 5, 255], + [235, 255, 7], + [150, 5, 61], + [120, 120, 70], + [8, 255, 51], + [255, 6, 82], + [143, 255, 140], + [204, 255, 4], + [255, 51, 7], + [204, 70, 3], + [0, 102, 200], + [61, 230, 250], + [255, 6, 51], + [11, 102, 255], + [255, 7, 71], + [255, 9, 224], + [9, 7, 230], + [220, 220, 220], + [255, 9, 92], + [112, 9, 255], + [8, 255, 214], + [7, 255, 224], + [255, 184, 6], + [10, 255, 71], + [255, 41, 10], + [7, 255, 255], + [224, 255, 8], + [102, 8, 255], + [255, 61, 6], + [255, 194, 7], + [255, 122, 8], + [0, 255, 20], + [255, 8, 41], + [255, 5, 153], + [6, 51, 255], + [235, 12, 255], + [160, 150, 20], + [0, 163, 255], + [140, 140, 140], + [250, 10, 15], + [20, 255, 0], + [31, 255, 0], + [255, 31, 0], + [255, 224, 0], + [153, 255, 0], + [0, 0, 255], + [255, 71, 0], + [0, 235, 255], + [0, 173, 255], + [31, 0, 255], + [11, 200, 200], + [255, 82, 0], + [0, 255, 245], + [0, 61, 255], + [0, 255, 112], + [0, 255, 133], + [255, 0, 0], + [255, 163, 0], + [255, 102, 0], + [194, 255, 0], + [0, 143, 255], + [51, 255, 0], + [0, 82, 255], + [0, 255, 41], + [0, 255, 173], + [10, 0, 255], + [173, 255, 0], + [0, 255, 153], + [255, 92, 0], + [255, 0, 255], + [255, 0, 245], + [255, 0, 102], + [255, 173, 0], + [255, 0, 20], + [255, 184, 184], + [0, 31, 255], + [0, 255, 61], + [0, 71, 255], + [255, 0, 204], + [0, 255, 194], + [0, 255, 82], + [0, 10, 255], + [0, 112, 255], + [51, 0, 255], + [0, 194, 255], + [0, 122, 255], + [0, 255, 163], + [255, 153, 0], + [0, 255, 10], + [255, 112, 0], + [143, 255, 0], + [82, 0, 255], + [163, 255, 0], + [255, 235, 0], + [8, 184, 170], + [133, 0, 255], + [0, 255, 92], + [184, 0, 255], + [255, 0, 31], + [0, 184, 255], + [0, 214, 255], + [255, 0, 112], + [92, 255, 0], + [0, 224, 255], + [112, 224, 255], + [70, 184, 160], + [163, 0, 255], + [153, 0, 255], + [71, 255, 0], + [255, 0, 163], + [255, 204, 0], + [255, 0, 143], + [0, 255, 235], + [133, 255, 0], + [255, 0, 235], + [245, 0, 255], + [255, 0, 122], + [255, 245, 0], + [10, 190, 212], + [214, 255, 0], + [0, 204, 255], + [20, 0, 255], + [255, 255, 0], + [0, 153, 255], + [0, 41, 255], + [0, 255, 204], + [41, 0, 255], + [41, 255, 0], + [173, 0, 255], + [0, 245, 255], + [71, 0, 255], + [122, 0, 255], + [0, 255, 184], + [0, 92, 255], + [184, 255, 0], + [0, 133, 255], + [255, 214, 0], + [25, 194, 194], + [102, 255, 0], + [92, 0, 255], + ] + color_codes = np.asarray(color_codes).flatten() + return list(color_codes) + + @staticmethod + def class_names() -> List: + return [ + "background", + "wall", + "building", + "sky", + "floor", + "tree", + "ceiling", + "road", + "bed ", + "windowpane", + "grass", + "cabinet", + "sidewalk", + "person", + "earth", + "door", + "table", + "mountain", + "plant", + "curtain", + "chair", + "car", + "water", + "painting", + "sofa", + "shelf", + "house", + "sea", + "mirror", + "rug", + "field", + "armchair", + "seat", + "fence", + "desk", + "rock", + "wardrobe", + "lamp", + "bathtub", + "railing", + "cushion", + "base", + "box", + "column", + "signboard", + "chest of drawers", + "counter", + "sand", + "sink", + "skyscraper", + "fireplace", + "refrigerator", + "grandstand", + "path", + "stairs", + "runway", + "case", + "pool table", + "pillow", + "screen door", + "stairway", + "river", + "bridge", + "bookcase", + "blind", + "coffee table", + "toilet", + "flower", + "book", + "hill", + "bench", + "countertop", + "stove", + "palm", + "kitchen island", + "computer", + "swivel chair", + "boat", + "bar", + "arcade machine", + "hovel", + "bus", + "towel", + "light", + "truck", + "tower", + "chandelier", + "awning", + "streetlight", + "booth", + "television receiver", + "airplane", + "dirt track", + "apparel", + "pole", + "land", + "bannister", + "escalator", + "ottoman", + "bottle", + "buffet", + "poster", + "stage", + "van", + "ship", + "fountain", + "conveyer belt", + "canopy", + "washer", + "plaything", + "swimming pool", + "stool", + "barrel", + "basket", + "waterfall", + "tent", + "bag", + "minibike", + "cradle", + "oven", + "ball", + "food", + "step", + "tank", + "trade name", + "microwave", + "pot", + "animal", + "bicycle", + "lake", + "dishwasher", + "screen", + "blanket", + "sculpture", + "hood", + "sconce", + "vase", + "traffic light", + "tray", + "ashcan", + "fan", + "pier", + "crt screen", + "plate", + "monitor", + "bulletin board", + "shower", + "radiator", + "glass", + "clock", + "flag", + ] + + def __repr__(self) -> str: + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + elif self.is_evaluation: + transforms_str = self._evaluation_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return ( + "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.images), + transforms_str, + ) + ) diff --git a/Adaptive Frequency Filters/data/datasets/segmentation/coco_segmentation.py b/Adaptive Frequency Filters/data/datasets/segmentation/coco_segmentation.py new file mode 100644 index 0000000..99eff75 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/segmentation/coco_segmentation.py @@ -0,0 +1,231 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +from typing import Optional, List, Dict, Union +import argparse + +from pycocotools.coco import COCO +from pycocotools import mask +import numpy as np +import os +from typing import Optional + + +from .. import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_pil as T + + +@register_dataset("coco", "segmentation") +class COCODataset(BaseImageDataset): + """ + Dataset class for the COCO dataset that maps classes to PASCAL VOC classes + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + """ + + :param opts: arguments + :param is_training: Training or validation mode + :param is_evaluation: Evaluation mode + """ + super().__init__( + opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + year = 2017 + split = "train" if is_training else "val" + ann_file = os.path.join( + self.root, "annotations/instances_{}{}.json".format(split, year) + ) + self.img_dir = os.path.join(self.root, "images/{}{}".format(split, year)) + self.split = split + self.coco = COCO(ann_file) + self.coco_mask = mask + self.ids = list(self.coco.imgs.keys()) + + self.ignore_label = 255 + self.bgrnd_idx = 0 + + setattr(opts, "model.segmentation.n_classes", len(self.class_names())) + + def __getitem__(self, batch_indexes_tup): + crop_size_h, crop_size_w, img_index = batch_indexes_tup + crop_size = (crop_size_h, crop_size_w) + + if self.is_training: + _transform = self._training_transforms( + size=crop_size, ignore_idx=self.ignore_label + ) + elif self.is_evaluation: + _transform = self._evaluation_transforms(size=crop_size) + else: + _transform = self._validation_transforms(size=crop_size) + + coco = self.coco + img_id = self.ids[img_index] + img_metadata = coco.loadImgs(img_id)[0] + path = img_metadata["file_name"] + + rgb_img = self.read_image_opencv(os.path.join(self.img_dir, path)) + cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id)) + + im_height, im_width = rgb_img.shape[:2] + + mask = self._gen_seg_mask( + cocotarget, img_metadata["height"], img_metadata["width"] + ) + + data = {"image": rgb_img, "mask": None if self.is_evaluation else mask} + + data = _transform(data) + + if self.is_evaluation: + # for evaluation purposes, resize only the input and not mask + data["mask"] = mask + + output_data = {"samples": data["image"], "targets": data["mask"]} + + if self.is_evaluation: + img_name = path.replace("jpg", "png") + mask = output_data.pop("targets") + output_data["targets"] = { + "mask": mask, + "file_name": img_name, + "im_width": im_width, + "im_height": im_height, + } + + return output_data + + def _gen_seg_mask(self, target, h, w): + mask = np.zeros((h, w), dtype=np.uint8) + coco_mask = self.coco_mask + coco_to_pascal = self.coco_to_pascal_mapping() + for instance in target: + rle = coco_mask.frPyObjects(instance["segmentation"], h, w) + m = coco_mask.decode(rle) + cat = instance["category_id"] + if cat in coco_to_pascal: + c = coco_to_pascal.index(cat) + else: + continue + if len(m.shape) < 3: + mask[:, :] += (mask == 0) * (m * c) + else: + mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype( + np.uint8 + ) + return mask + + def _training_transforms(self, size: tuple, ignore_idx: Optional[int] = 255): + aug_list = [ + T.RandomResize(opts=self.opts), + T.RandomCrop(opts=self.opts, size=size), + T.RandomHorizontalFlip(opts=self.opts), + T.ToTensor(opts=self.opts), + ] + + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: tuple, *args, **kwargs): + aug_list = [T.Resize(opts=self.opts), T.ToTensor(opts=self.opts)] + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _evaluation_transforms(self, size: tuple, *args, **kwargs): + aug_list = [] + if getattr(self.opts, "evaluation.segmentation.resize_input_images", False): + aug_list.append(T.Resize(opts=self.opts)) + + aug_list.append(T.ToTensor(opts=self.opts)) + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def __len__(self): + return len(self.ids) + + @staticmethod + def class_names() -> List: + return [ + "background", + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "potted_plant", + "sheep", + "sofa", + "train", + "tv_monitor", + ] + + @staticmethod + def coco_to_pascal_mapping(): + return [ + 0, + 5, + 2, + 16, + 9, + 44, + 6, + 3, + 17, + 62, + 21, + 67, + 18, + 19, + 4, + 1, + 64, + 20, + 63, + 7, + 72, + ] + + def __repr__(self): + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + elif self.is_evaluation: + transforms_str = self._evaluation_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\t\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.ids), + transforms_str, + ) diff --git a/Adaptive Frequency Filters/data/datasets/segmentation/pascal_voc.py b/Adaptive Frequency Filters/data/datasets/segmentation/pascal_voc.py new file mode 100644 index 0000000..b862745 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/segmentation/pascal_voc.py @@ -0,0 +1,276 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +from typing import Optional, List, Tuple, Dict +import argparse +import numpy as np + +from .. import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import image_pil as T + + +@register_dataset("pascal", "segmentation") +class PascalVOCDataset(BaseImageDataset): + """ + Dataset class for the PASCAL VOC 2012 dataset + + The structure of PASCAL VOC dataset should be something like this: :: + + pascal_voc/VOCdevkit/VOC2012/Annotations + pascal_voc/VOCdevkit/VOC2012/JPEGImages + pascal_voc/VOCdevkit/VOC2012/SegmentationClass + pascal_voc/VOCdevkit/VOC2012/SegmentationClassAug_Visualization + pascal_voc/VOCdevkit/VOC2012/ImageSets + pascal_voc/VOCdevkit/VOC2012/list + pascal_voc/VOCdevkit/VOC2012/SegmentationClassAug + pascal_voc/VOCdevkit/VOC2012/SegmentationObject + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + use_coco_data = getattr(opts, "dataset.pascal.use_coco_data", False) + coco_root_dir = getattr(opts, "dataset.pascal.coco_root_dir", None) + root = self.root + + voc_root_dir = os.path.join(root, "VOC2012") + voc_list_dir = os.path.join(voc_root_dir, "list") + + coco_data_file = None + if self.is_training: + # use the PASCAL VOC 2012 train data with augmented data + data_file = os.path.join(voc_list_dir, "train_aug.txt") + if use_coco_data and coco_root_dir is not None: + coco_data_file = os.path.join(coco_root_dir, "train_2017.txt") + assert os.path.isfile( + coco_data_file + ), "COCO data file does not exist at: {}".format(coco_root_dir) + else: + data_file = os.path.join(voc_list_dir, "val.txt") + + self.images = [] + self.masks = [] + with open(data_file, "r") as lines: + for line in lines: + line_split = line.split(" ") + rgb_img_loc = voc_root_dir + os.sep + line_split[0].strip() + mask_img_loc = voc_root_dir + os.sep + line_split[1].strip() + assert os.path.isfile( + rgb_img_loc + ), "RGB file does not exist at: {}".format(rgb_img_loc) + assert os.path.isfile( + mask_img_loc + ), "Mask image does not exist at: {}".format(rgb_img_loc) + self.images.append(rgb_img_loc) + self.masks.append(mask_img_loc) + + # if you want to use Coarse data for training + if self.is_training and coco_data_file is not None: + with open(coco_data_file, "r") as lines: + for line in lines: + line_split = line.split(" ") + rgb_img_loc = coco_root_dir + os.sep + line_split[0].rstrip() + mask_img_loc = coco_root_dir + os.sep + line_split[1].rstrip() + # assert os.path.isfile(rgb_img_loc) + # assert os.path.isfile(mask_img_loc) + self.images.append(rgb_img_loc) + self.masks.append(mask_img_loc) + self.use_coco_data = use_coco_data + self.ignore_label = 255 + self.bgrnd_idx = 0 + setattr(opts, "model.segmentation.n_classes", len(self.class_names())) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--dataset.pascal.use-coco-data", + action="store_true", + help="Use MS-COCO data for training", + ) + group.add_argument( + "--dataset.pascal.coco-root-dir", + type=str, + default=None, + help="Location of MS-COCO data", + ) + return parser + + @staticmethod + def color_palette(): + color_codes = [ + [0, 0, 0], + [128, 0, 0], + [0, 128, 0], + [128, 128, 0], + [0, 0, 128], + [128, 0, 128], + [0, 128, 128], + [128, 128, 128], + [64, 0, 0], + [192, 0, 0], + [64, 128, 0], + [192, 128, 0], + [64, 0, 128], + [192, 0, 128], + [64, 128, 128], + [192, 128, 128], + [0, 64, 0], + [128, 64, 0], + [0, 192, 0], + [128, 192, 0], + [0, 64, 128], + ] + + color_codes = np.asarray(color_codes).flatten() + return list(color_codes) + + @staticmethod + def class_names() -> List: + return [ + "background", + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "potted_plant", + "sheep", + "sofa", + "train", + "tv_monitor", + ] + + def _training_transforms(self, size: tuple): + first_aug = T.RandomShortSizeResize(opts=self.opts) + aug_list = [ + T.RandomHorizontalFlip(opts=self.opts), + T.RandomCrop(opts=self.opts, size=size, ignore_idx=self.ignore_label), + ] + + if getattr(self.opts, "image_augmentation.random_gaussian_noise.enable", False): + aug_list.append(T.RandomGaussianBlur(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.photo_metric_distort.enable", False): + aug_list.append(T.PhotometricDistort(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.random_rotate.enable", False): + aug_list.append(T.RandomRotate(opts=self.opts)) + + if getattr(self.opts, "image_augmentation.random_order.enable", False): + new_aug_list = [ + first_aug, + T.RandomOrder(opts=self.opts, img_transforms=aug_list), + T.ToTensor(opts=self.opts), + ] + return T.Compose(opts=self.opts, img_transforms=new_aug_list) + else: + aug_list.insert(0, first_aug) + aug_list.append(T.ToTensor(opts=self.opts)) + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _validation_transforms(self, size: tuple, *args, **kwargs): + aug_list = [T.Resize(opts=self.opts), T.ToTensor(opts=self.opts)] + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def _evaluation_transforms(self, size: tuple, *args, **kwargs): + aug_list = [] + if getattr(self.opts, "evaluation.segmentation.resize_input_images", False): + # we want to resize while maintaining aspect ratio. So, we pass img_size argument to resize function + aug_list.append(T.Resize(opts=self.opts, img_size=min(size))) + + aug_list.append(T.ToTensor(opts=self.opts)) + return T.Compose(opts=self.opts, img_transforms=aug_list) + + def __getitem__(self, batch_indexes_tup: Tuple) -> Dict: + crop_size_h, crop_size_w, img_index = batch_indexes_tup + crop_size = (crop_size_h, crop_size_w) + + if self.is_training: + _transform = self._training_transforms(size=crop_size) + elif self.is_evaluation: + _transform = self._evaluation_transforms(size=crop_size) + else: + _transform = self._validation_transforms(size=crop_size) + + img = self.read_image_pil(self.images[img_index]) + mask = self.read_mask_pil(self.masks[img_index]) + + data = {"image": img} + if not self.is_evaluation: + data["mask"] = mask + + data = _transform(data) + + if self.is_evaluation: + # for evaluation purposes, resize only the input and not mask + data["mask"] = self.convert_mask_to_tensor(mask) + + output_data = {"samples": data["image"], "targets": data["mask"]} + + if self.is_evaluation: + im_width, im_height = img.size + img_name = self.images[img_index].split(os.sep)[-1].replace("jpg", "png") + mask = output_data.pop("targets") + output_data["targets"] = { + "mask": mask, + "file_name": img_name, + "im_width": im_width, + "im_height": im_height, + } + + return output_data + + def __len__(self): + return len(self.images) + + def __repr__(self): + from utils.tensor_utils import image_size_from_opts + + im_h, im_w = image_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + elif self.is_evaluation: + transforms_str = self._evaluation_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\tuse_coco={}\n\ttransforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + len(self.images), + self.use_coco_data, + transforms_str, + ) diff --git a/Adaptive Frequency Filters/data/datasets/video_classification/__init__.py b/Adaptive Frequency Filters/data/datasets/video_classification/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/data/datasets/video_classification/kinetics.py b/Adaptive Frequency Filters/data/datasets/video_classification/kinetics.py new file mode 100644 index 0000000..ccdfa33 --- /dev/null +++ b/Adaptive Frequency Filters/data/datasets/video_classification/kinetics.py @@ -0,0 +1,287 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os.path +from typing import Optional, Tuple, List, Union +import torch +import pathlib +import glob +import argparse +import pickle + +from utils import logger +from utils.download_utils import get_local_path +from utils.ddp_utils import is_master + +from .. import register_dataset +from ..dataset_base import BaseImageDataset +from ...transforms import video as T +from ...video_reader import get_video_reader +from ...collate_fns import register_collate_fn + + +@register_dataset(name="kinetics", task="video_classification") +class KineticsDataset(BaseImageDataset): + """ + Dataset class for the Kinetics dataset + + Args: + opts: command-line arguments + is_training (Optional[bool]): A flag used to indicate training or validation mode. Default: True + is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode. Default: False + """ + + def __init__( + self, + opts, + is_training: Optional[bool] = True, + is_evaluation: Optional[bool] = False, + *args, + **kwargs, + ) -> None: + + super(KineticsDataset, self).__init__( + opts=opts, is_training=is_training, is_evaluation=is_evaluation + ) + + if not os.path.isdir(self.root): + logger.error("Directory does not exist: {}".format(self.root)) + + pyav_video_reader = get_video_reader(opts=opts, is_training=is_training) + + if is_training: + metadata_file = getattr(opts, "dataset.kinetics.metadata_file_train", None) + else: + metadata_file = getattr(opts, "dataset.kinetics.metadata_file_val", None) + + if metadata_file is not None: + # internally, we take care that master node only downloads the file + metadata_file = get_local_path(opts=opts, path=metadata_file) + with open(metadata_file, "rb") as f: + self.samples = pickle.load(f) + assert isinstance(self.samples, List) + else: + # each folder is a class + class_names = sorted( + (f.name for f in pathlib.Path(self.root).iterdir() if f.is_dir()) + ) + + samples = [] + extensions = ["avi", "mp4"] + for cls_idx in range(len(class_names)): + cls_name = class_names[cls_idx] + class_folder = os.path.join(self.root, cls_name) + for video_path in glob.glob(f"{class_folder}/*"): + file_extn = video_path.split(".")[-1] + if ( + (file_extn in extensions) + and os.path.isfile(video_path) + and pyav_video_reader.check_video(filename=video_path) + ): + samples.append({"label": cls_idx, "video_path": video_path}) + self.samples = samples + results_loc = getattr(opts, "common.results_loc", None) + if is_master(opts): + stage = "train" if is_training else "val" + metadata_file_loc = f"{results_loc}/kinetics_metadata_{stage}.pkl" + + with open(metadata_file_loc, "wb") as f: + pickle.dump(self.samples, f) + logger.log("Metadata file saved at: {}".format(metadata_file_loc)) + + self.pyav_video_reader = pyav_video_reader + + def __len__(self): + return len(self.samples) + + def _training_transforms(self, size: tuple or int): + """ + + :param size: crop size (H, W) + :return: list of augmentation methods + """ + aug_list = [ + T.RandomResizedCrop(opts=self.opts, size=size), + T.RandomHorizontalFlip(opts=self.opts), + ] + return T.Compose(opts=self.opts, video_transforms=aug_list) + + def _validation_transforms(self, size: Union[Tuple, List, int]): + """ + + :param size: crop size (H, W) + :return: list of augmentation methods + """ + aug_list = [ + T.Resize(opts=self.opts), + T.CenterCrop(opts=self.opts, size=size), + ] + + return T.Compose(opts=self.opts, video_transforms=aug_list) + + def _evaluation_transforms(self, size: tuple): + """ + + :param size: crop size (H, W) + :return: list of augmentation methods + """ + return self._validation_transforms(size=size) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--dataset.kinetics.metadata-file-train", + type=str, + default=None, + help="Metadata file for kinetics train set", + ) + group.add_argument( + "--dataset.kinetics.metadata-file-val", + type=str, + default=None, + help="Metadata file for kinetics validation set", + ) + return parser + + def __getitem__(self, batch_indexes_tup): + ( + crop_size_h, + crop_size_w, + index, + n_frames_to_sample, + clips_per_video, + ) = batch_indexes_tup + if self.is_training: + transform_fn = self._training_transforms(size=(crop_size_h, crop_size_w)) + else: # same for validation and evaluation + transform_fn = self._validation_transforms(size=(crop_size_h, crop_size_w)) + + try: + info: dict = self.samples[index] + target = info["label"] + + # Default is Tensor of size [K, N, C, H, W]. + # If --dataset.kinetics.frame-stack-format="channel_first", then clip is of size [K, C, N, H, W] + # here, K --> no. of clips, C --> Image channels, N --> Number of frames per clip, H --> Height, W --> Width + input_video = self.pyav_video_reader.process_video( + vid_filename=info["video_path"], + n_frames_per_clip=n_frames_to_sample, + clips_per_video=clips_per_video, + video_transform_fn=transform_fn, + is_training=self.is_training, + ) + + if input_video is None: + logger.log("Corrupted video file: {}".format(info["video_path"])) + input_video = self.pyav_video_reader.dummy_video( + clips_per_video=clips_per_video, + n_frames_to_sample=n_frames_to_sample, + height=crop_size_h, + width=crop_size_w, + ) + + data = {"image": input_video} + target = getattr(self.opts, "loss.ignore_idx", -1) + else: + data = {"image": input_video} + + except Exception as e: + logger.log("Unable to load index: {}. Error: {}".format(index, str(e))) + input_video = self.pyav_video_reader.dummy_video( + clips_per_video=clips_per_video, + n_frames_to_sample=n_frames_to_sample, + height=crop_size_h, + width=crop_size_w, + ) + + target = getattr(self.opts, "loss.ignore_idx", -1) + data = {"image": input_video} + + output_data = { + "samples": data.pop("image"), + # target is a 0-dimensional tensor + "targets": torch.LongTensor(size=(input_video.shape[0],)).fill_(target), + } + + return output_data + + def __repr__(self): + from utils.tensor_utils import video_size_from_opts + + im_h, im_w, n_frames = video_size_from_opts(opts=self.opts) + + if self.is_training: + transforms_str = self._training_transforms(size=(im_h, im_w)) + else: + transforms_str = self._validation_transforms(size=(im_h, im_w)) + + if hasattr(self.pyav_video_reader, "frame_transforms_str"): + frame_transforms_str = self.pyav_video_reader.frame_transforms_str + else: + frame_transforms_str = None + + return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n\tvideo_transforms={}\n\tframe_transforms={}\n)".format( + self.__class__.__name__, + self.root, + self.is_training, + self.__len__(), + transforms_str, + frame_transforms_str, + ) + + +@register_collate_fn(name="kinetics_collate_fn") +def kinetics_collate_fn(batch: List, opts): + batch_size = len(batch) + + images = [] + labels = [] + for b in range(batch_size): + b_label = batch[b]["targets"] + images.append(batch[b]["samples"]) + labels.append(b_label) + + images = torch.cat(images, dim=0) + labels = torch.cat(labels, dim=0) + + # check for contiguous + if not images.is_contiguous(): + images = images.contiguous() + + if not labels.is_contiguous(): + labels = labels.contiguous() + + return {"samples": images, "targets": labels} + + +@register_collate_fn(name="kinetics_collate_fn_train") +def kinetics_collate_fn_train(batch: List, opts): + batch_size = len(batch) + ignore_label = getattr(opts, "loss.ignore_idx", -1) + + images = [] + labels = [] + for b in range(batch_size): + b_label = batch[b]["targets"] + if ignore_label in b_label: + continue + images.append(batch[b]["samples"]) + labels.append(b_label) + + images = torch.cat(images, dim=0) + labels = torch.cat(labels, dim=0) + + # check for contiguous + if not images.is_contiguous(): + images = images.contiguous() + + if not labels.is_contiguous(): + labels = labels.contiguous() + + return {"samples": images, "targets": labels} diff --git a/Adaptive Frequency Filters/data/loader/__init__.py b/Adaptive Frequency Filters/data/loader/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/data/loader/dataloader.py b/Adaptive Frequency Filters/data/loader/dataloader.py new file mode 100644 index 0000000..062d99d --- /dev/null +++ b/Adaptive Frequency Filters/data/loader/dataloader.py @@ -0,0 +1,55 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from typing import Optional, Union, List +from torch.utils.data import DataLoader + +from ..sampler.base_sampler import BaseSamplerDP, BaseSamplerDDP +from ..datasets.dataset_base import BaseImageDataset + + +class affnetDataLoader(DataLoader): + """This class extends PyTorch's Dataloader""" + + def __init__( + self, + dataset: BaseImageDataset, + batch_size: int, + batch_sampler: Union[BaseSamplerDP, BaseSamplerDDP], + num_workers: Optional[int] = 1, + pin_memory: Optional[bool] = False, + persistent_workers: Optional[bool] = False, + collate_fn: Optional = None, + prefetch_factor: Optional[int] = 2, + *args, + **kwargs + ): + super(affnetDataLoader, self).__init__( + dataset=dataset, + batch_size=batch_size, + batch_sampler=batch_sampler, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + collate_fn=collate_fn, + prefetch_factor=prefetch_factor, + ) + + def update_indices(self, new_indices: List, *args, **kwargs): + """Update indices in the dataset class""" + if hasattr(self.batch_sampler, "img_indices") and hasattr( + self.batch_sampler, "update_indices" + ): + self.batch_sampler.update_indices(new_indices) + + def samples_in_dataset(self): + """Number of samples in the dataset""" + return len(self.batch_sampler.img_indices) + + def get_sample_indices(self) -> List: + """Sample IDs""" + return self.batch_sampler.img_indices diff --git a/Adaptive Frequency Filters/data/sampler/__init__.py b/Adaptive Frequency Filters/data/sampler/__init__.py new file mode 100644 index 0000000..74e8971 --- /dev/null +++ b/Adaptive Frequency Filters/data/sampler/__init__.py @@ -0,0 +1,117 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import importlib +from typing import Optional +from utils import logger +import argparse + +from utils.ddp_utils import is_master + +from .base_sampler import BaseSamplerDDP, BaseSamplerDP + +SAMPLER_REGISTRY = {} + + +def register_sampler(name): + def register_sampler_class(cls): + if name in SAMPLER_REGISTRY: + raise ValueError( + "Cannot register duplicate sampler class ({})".format(name) + ) + + if not (issubclass(cls, BaseSamplerDDP) or issubclass(cls, BaseSamplerDP)): + raise ValueError( + "Sampler ({}: {}) must extend BaseSamplerDDP or BaseSamplerDP".format( + name, cls.__name__ + ) + ) + + SAMPLER_REGISTRY[name] = cls + return cls + + return register_sampler_class + + +def build_sampler(opts, n_data_samples: int, is_training: Optional[bool] = False): + sampler_name = getattr(opts, "sampler.name", "variable_batch_sampler") + is_distributed = getattr(opts, "ddp.use_distributed", False) + + if is_distributed and sampler_name.split("_")[-1] != "ddp": + sampler_name = sampler_name + "_ddp" + + sampler = None + if sampler_name in SAMPLER_REGISTRY: + sampler = SAMPLER_REGISTRY[sampler_name]( + opts, n_data_samples=n_data_samples, is_training=is_training + ) + else: + supp_list = list(SAMPLER_REGISTRY.keys()) + supp_str = ( + "Sampler ({}) not yet supported. \n Supported optimizers are:".format( + sampler_name + ) + ) + for i, m_name in enumerate(supp_list): + supp_str += "\n\t {}: {}".format(i, logger.color_text(m_name)) + logger.error(supp_str) + + return sampler + + +def sampler_common_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--sampler.name", type=str, default="batch_sampler", help="Name of the sampler" + ) + parser.add_argument( + "--sampler.use-shards", + action="store_true", + help="Use data sharding. Only applicable to DDP", + ) + parser.add_argument( + "--sampler.num-repeats", + type=int, + default=1, + help="Repeat samples, as in repeated augmentation", + ) + + parser.add_argument( + "--sampler.truncated-repeat-aug-sampler", + action="store_true", + help="Use truncated repeated augmentation sampler", + ) + + parser.add_argument( + "--sampler.disable-shuffle-sharding", + action="store_true", + help="Disable shuffling while sharding for extremely large datasets", + ) + + return parser + + +def arguments_sampler(parser: argparse.ArgumentParser): + parser = sampler_common_args(parser=parser) + + # add classification specific arguments + for k, v in SAMPLER_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + + return parser + + +# automatically import the samplers +sampler_dir = os.path.dirname(__file__) +for file in os.listdir(sampler_dir): + path = os.path.join(sampler_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + sampler_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("data.sampler." + sampler_name) diff --git a/Adaptive Frequency Filters/data/sampler/base_sampler.py b/Adaptive Frequency Filters/data/sampler/base_sampler.py new file mode 100644 index 0000000..80625bb --- /dev/null +++ b/Adaptive Frequency Filters/data/sampler/base_sampler.py @@ -0,0 +1,296 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch.utils.data.sampler import Sampler +from typing import Optional +import torch.distributed as dist +import math +import argparse +import copy +import numpy as np +import random + + +class BaseSamplerDP(Sampler): + """ + Base class for DataParallel Sampler + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + # max between 1 and number of available GPUs. 1 because for supporting CPUs + n_gpus: int = max(1, torch.cuda.device_count()) + batch_size_gpu0: int = ( + getattr(opts, "dataset.train_batch_size0", 32) + if is_training + else getattr(opts, "dataset.val_batch_size0", 32) + ) + + n_samples_per_gpu = int(math.ceil(n_data_samples * 1.0 / n_gpus)) + total_size = n_samples_per_gpu * n_gpus + + indexes = [idx for idx in range(n_data_samples)] + # This ensures that we can divide the batches evenly across GPUs + indexes += indexes[: (total_size - n_data_samples)] + assert total_size == len(indexes) + + self.img_indices = indexes + self.n_samples = total_size + self.batch_size_gpu0 = batch_size_gpu0 + self.n_gpus = n_gpus + self.shuffle = True if is_training else False + self.epoch = 0 + + self.num_repeats = getattr(opts, "sampler.num_repeats", 1) if is_training else 1 + self.trunc_rep_aug = getattr( + opts, "sampler.truncated_repeat_aug_sampler", False + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + def extra_repr(self): + extra_repr_str = "\n\t num_repeat={}" "\n\t trunc_rep_aug={}".format( + self.num_repeats, self.trunc_rep_aug + ) + return extra_repr_str + + def get_indices(self): + img_indices = copy.deepcopy(self.img_indices) + if self.shuffle: + random.seed(self.epoch) + random.shuffle(img_indices) + + if self.num_repeats > 1: + # Apply repeated augmentation + """Assume that we have [0, 1, 2, 3] samples. With repeated augmentation, + we first repeat the samples [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3] and then select 4 + samples [0, 0, 0, 1]. Note that we do shuffle at the beginning, so samples are not the + same at every iteration. + """ + n_samples_before_repeat = len(img_indices) + img_indices = np.repeat(img_indices, repeats=self.num_repeats) + img_indices = list(img_indices) + if self.trunc_rep_aug: + img_indices = img_indices[:n_samples_before_repeat] + return img_indices + + def __iter__(self): + raise NotImplementedError + + def __len__(self): + return len(self.img_indices) * (1 if self.trunc_rep_aug else self.num_repeats) + + def set_epoch(self, epoch): + self.epoch = epoch + + def update_scales(self, epoch, is_master_node=False, *args, **kwargs): + pass + + def update_indices(self, new_indices): + self.img_indices = new_indices + + def __repr__(self): + return "{}()".format(self.__class__.__name__) + + +class BaseSamplerDDP(Sampler): + """ + Base class for DistributedDataParallel Sampler + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + # max between 1 and number of available GPUs. 1 because for supporting CPUs + batch_size_gpu0: int = ( + getattr(opts, "dataset.train_batch_size0", 32) + if is_training + else getattr(opts, "dataset.val_batch_size0", 32) + ) + + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + + num_replicas = dist.get_world_size() + rank = dist.get_rank() + gpus_node_i = max(1, torch.cuda.device_count()) + + num_samples_per_replica = int(math.ceil(n_data_samples * 1.0 / num_replicas)) + total_size = num_samples_per_replica * num_replicas + + img_indices = [idx for idx in range(n_data_samples)] + img_indices += img_indices[: (total_size - n_data_samples)] + assert len(img_indices) == total_size + + self.img_indices = img_indices + self.n_samples_per_replica = num_samples_per_replica + self.shuffle = True if is_training else False + self.epoch = 0 + self.rank = rank + self.batch_size_gpu0 = batch_size_gpu0 + self.num_replicas = num_replicas + self.skip_sample_indices = [] + self.node_id = rank // gpus_node_i + + self.num_nodes = max(1, num_replicas // gpus_node_i) + self.local_rank = rank % gpus_node_i + self.num_gpus_node_i = gpus_node_i + + self.sharding = ( + getattr(opts, "sampler.use_shards", False) if is_training else False + ) + self.num_repeats = getattr(opts, "sampler.num_repeats", 1) if is_training else 1 + self.trunc_rep_aug = ( + getattr(opts, "sampler.truncated_repeat_aug_sampler", False) + if self.num_repeats + else False + ) + self.n_samples_per_replica = num_samples_per_replica * ( + 1 if self.trunc_rep_aug else self.num_repeats + ) + self.disable_shuffle_sharding = getattr( + opts, "sampler.disable_shuffle_sharding", False + ) + + def extra_repr(self): + extra_repr_str = ( + "\n\t num_repeat={}" + "\n\t trunc_rep_aug={}" + "\n\t sharding={}" + "\n\t disable_shuffle_sharding={}".format( + self.num_repeats, + self.trunc_rep_aug, + self.sharding, + self.disable_shuffle_sharding, + ) + ) + return extra_repr_str + + def get_indices_rank_i(self): + img_indices = copy.deepcopy(self.img_indices) + if self.shuffle: + random.seed(self.epoch) + + if self.sharding: + """If we have 8 samples, say [0, 1, 2, 3, 4, 5, 6, 7], and we have two nodes, + then node 0 will receive first 4 samples and node 1 will receive last 4 samples. + + note: + This strategy is useful when dataset is large and we want to process subset of dataset on each node. + """ + + # compute number pf samples per node. + # Each node may have multiple GPUs + # Node id = rank // num_gpus_per_rank + samples_per_node = int(math.ceil(len(img_indices) / self.num_nodes)) + indices_node_i = img_indices[ + self.node_id + * samples_per_node : (self.node_id + 1) + * samples_per_node + ] + + # Ensure that each node has equal number of samples + if len(indices_node_i) < samples_per_node: + indices_node_i += indices_node_i[ + : (samples_per_node - len(indices_node_i)) + ] + + # Note: For extremely large datasets, we may want to disable shuffling for efficient data loading + if not self.disable_shuffle_sharding: + # shuffle the indices within a node. + random.shuffle(indices_node_i) + + if self.num_repeats > 1: + """Assume that we have [0, 1, 2, 3] samples in rank_i. With repeated augmentation, + we first repeat the samples [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3] and then select 4 + samples [0, 0, 0, 1]. Note shuffling at the beginning + """ + # Apply repeated augmentation + n_samples_before_repeat = len(indices_node_i) + indices_node_i = np.repeat(indices_node_i, repeats=self.num_repeats) + indices_node_i = list(indices_node_i) + if self.trunc_rep_aug: + indices_node_i = indices_node_i[:n_samples_before_repeat] + + # divide the samples among each GPU in a node + indices_rank_i = indices_node_i[ + self.local_rank : len(indices_node_i) : self.num_gpus_node_i + ] + else: + """If we have 8 samples, say [0, 1, 2, 3, 4, 5, 6, 7], and we have two nodes, + then node 0 will receive [0, 2, 4, 6] and node 1 will receive [1, 3, 4, 7]. + + note: + This strategy is useful when each data sample is stored independently, and is + default in many frameworks + """ + random.shuffle(img_indices) + + if self.num_repeats > 1: + # Apply repeated augmentation + n_samples_before_repeat = len(img_indices) + img_indices = np.repeat(img_indices, repeats=self.num_repeats) + img_indices = list(img_indices) + if self.trunc_rep_aug: + img_indices = img_indices[:n_samples_before_repeat] + + # divide the samples among each GPU in a node + indices_rank_i = img_indices[ + self.rank : len(img_indices) : self.num_replicas + ] + else: + indices_rank_i = img_indices[ + self.rank : len(self.img_indices) : self.num_replicas + ] + return indices_rank_i + + def __iter__(self): + raise NotImplementedError + + def __len__(self): + return (len(self.img_indices) // self.num_replicas) * ( + 1 if self.trunc_rep_aug else self.num_repeats + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + def set_epoch(self, epoch): + self.epoch = epoch + + def update_scales(self, epoch, is_master_node=False, *args, **kwargs): + pass + + def update_indices(self, new_indices): + self.img_indices = new_indices + + def __repr__(self): + return "{}()".format(self.__class__.__name__) diff --git a/Adaptive Frequency Filters/data/sampler/batch_sampler.py b/Adaptive Frequency Filters/data/sampler/batch_sampler.py new file mode 100644 index 0000000..397ca45 --- /dev/null +++ b/Adaptive Frequency Filters/data/sampler/batch_sampler.py @@ -0,0 +1,156 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- +import copy +import random +import argparse +from typing import Optional +import math +import numpy as np + +from common import DEFAULT_IMAGE_WIDTH, DEFAULT_IMAGE_HEIGHT + +from . import register_sampler, BaseSamplerDDP, BaseSamplerDP + + +@register_sampler(name="batch_sampler") +class BatchSampler(BaseSamplerDP): + """ + Standard Batch Sampler for data parallel + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, n_data_samples=n_data_samples, is_training=is_training + ) + crop_size_w: int = getattr( + opts, "sampler.bs.crop_size_width", DEFAULT_IMAGE_WIDTH + ) + crop_size_h: int = getattr( + opts, "sampler.bs.crop_size_height", DEFAULT_IMAGE_HEIGHT + ) + + self.crop_size_w = crop_size_w + self.crop_size_h = crop_size_h + + def __iter__(self): + img_indices = self.get_indices() + + start_index = 0 + batch_size = self.batch_size_gpu0 + n_samples = len(img_indices) + while start_index < n_samples: + + end_index = min(start_index + batch_size, n_samples) + batch_ids = img_indices[start_index:end_index] + start_index += batch_size + + if len(batch_ids) > 0: + batch = [ + (self.crop_size_h, self.crop_size_w, b_id) for b_id in batch_ids + ] + yield batch + + def __repr__(self): + repr_str = "{}(".format(self.__class__.__name__) + repr_str += "\n\tbase_im_size=(h={}, w={})" "\n\tbase_batch_size={}".format( + self.crop_size_h, self.crop_size_w, self.batch_size_gpu0 + ) + repr_str += self.extra_repr() + repr_str += "\n)" + return repr_str + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Batch sampler", description="Arguments related to Batch sampler" + ) + group.add_argument( + "--sampler.bs.crop-size-width", + default=DEFAULT_IMAGE_WIDTH, + type=int, + help="Base crop size (along width) during training", + ) + group.add_argument( + "--sampler.bs.crop-size-height", + default=DEFAULT_IMAGE_HEIGHT, + type=int, + help="Base crop size (along height) during training", + ) + return parser + + +@register_sampler(name="batch_sampler_ddp") +class BatchSamplerDDP(BaseSamplerDDP): + """ + Standard Batch Sampler for distributed data parallel + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, n_data_samples=n_data_samples, is_training=is_training + ) + crop_size_w: int = getattr( + opts, "sampler.bs.crop_size_width", DEFAULT_IMAGE_WIDTH + ) + crop_size_h: int = getattr( + opts, "sampler.bs.crop_size_height", DEFAULT_IMAGE_HEIGHT + ) + + self.crop_size_w = crop_size_w + self.crop_size_h = crop_size_h + + def __iter__(self): + indices_rank_i = self.get_indices_rank_i() + start_index = 0 + batch_size = self.batch_size_gpu0 + + n_samples_rank_i = len(indices_rank_i) + while start_index < n_samples_rank_i: + end_index = min(start_index + batch_size, n_samples_rank_i) + batch_ids = indices_rank_i[start_index:end_index] + n_batch_samples = len(batch_ids) + if n_batch_samples != batch_size: + batch_ids += indices_rank_i[: (batch_size - n_batch_samples)] + start_index += batch_size + + if len(batch_ids) > 0: + batch = [ + (self.crop_size_h, self.crop_size_w, b_id) for b_id in batch_ids + ] + yield batch + + def __repr__(self): + repr_str = "{}(".format(self.__class__.__name__) + repr_str += "\n\tbase_im_size=(h={}, w={})" "\n\tbase_batch_size={}".format( + self.crop_size_h, self.crop_size_w, self.batch_size_gpu0 + ) + repr_str += self.extra_repr() + repr_str += "\n)" + return repr_str diff --git a/Adaptive Frequency Filters/data/sampler/multi_scale_sampler.py b/Adaptive Frequency Filters/data/sampler/multi_scale_sampler.py new file mode 100644 index 0000000..70a9f32 --- /dev/null +++ b/Adaptive Frequency Filters/data/sampler/multi_scale_sampler.py @@ -0,0 +1,340 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- +import copy +import random +import argparse +from utils import logger +from typing import Optional +from common import DEFAULT_IMAGE_WIDTH, DEFAULT_IMAGE_HEIGHT +import numpy as np + +from . import register_sampler, BaseSamplerDP, BaseSamplerDDP +from .utils import _image_batch_pairs + + +@register_sampler(name="multi_scale_sampler") +class MultiScaleSampler(BaseSamplerDP): + """ + Multi-scale Batch Sampler for data parallel + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, n_data_samples=n_data_samples, is_training=is_training + ) + + crop_size_w: int = getattr( + opts, "sampler.msc.crop_size_width", DEFAULT_IMAGE_WIDTH + ) + crop_size_h: int = getattr( + opts, "sampler.msc.crop_size_height", DEFAULT_IMAGE_HEIGHT + ) + + min_crop_size_w: int = getattr(opts, "sampler.msc.min_crop_size_width", 160) + max_crop_size_w: int = getattr(opts, "sampler.msc.max_crop_size_width", 320) + + min_crop_size_h: int = getattr(opts, "sampler.msc.min_crop_size_height", 160) + max_crop_size_h: int = getattr(opts, "sampler.msc.max_crop_size_height", 320) + + scale_inc: bool = getattr(opts, "sampler.msc.scale_inc", False) + scale_ep_intervals: list or int = getattr( + opts, "sampler.msc.ep_intervals", [40] + ) + scale_inc_factor: float = getattr(opts, "sampler.msc.scale_inc_factor", 0.25) + + check_scale_div_factor: int = getattr(opts, "sampler.msc.check_scale", 32) + max_img_scales: int = getattr(opts, "sampler.msc.max_n_scales", 10) + + if isinstance(scale_ep_intervals, int): + scale_ep_intervals = [scale_ep_intervals] + + self.min_crop_size_w = min_crop_size_w + self.max_crop_size_w = max_crop_size_w + self.min_crop_size_h = min_crop_size_h + self.max_crop_size_h = max_crop_size_h + + self.crop_size_w = crop_size_w + self.crop_size_h = crop_size_h + + self.scale_inc_factor = scale_inc_factor + self.scale_ep_intervals = scale_ep_intervals + + self.max_img_scales = max_img_scales + self.check_scale_div_factor = check_scale_div_factor + self.scale_inc = scale_inc + + if is_training: + self.img_batch_tuples = _image_batch_pairs( + crop_size_h=self.crop_size_h, + crop_size_w=self.crop_size_w, + batch_size_gpu0=self.batch_size_gpu0, + n_gpus=self.n_gpus, + max_scales=self.max_img_scales, + check_scale_div_factor=self.check_scale_div_factor, + min_crop_size_w=self.min_crop_size_w, + max_crop_size_w=self.max_crop_size_w, + min_crop_size_h=self.min_crop_size_h, + max_crop_size_h=self.max_crop_size_h, + ) + # over-ride the batch-size + self.img_batch_tuples = [ + (h, w, self.batch_size_gpu0) for h, w, b in self.img_batch_tuples + ] + else: + self.img_batch_tuples = [(crop_size_h, crop_size_w, self.batch_size_gpu0)] + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Multi-scale sampler", description="Multi-scale sampler" + ) + group.add_argument( + "--sampler.msc.crop-size-width", + default=DEFAULT_IMAGE_WIDTH, + type=int, + help="Base crop size (along width) during training", + ) + group.add_argument( + "--sampler.msc.crop-size-height", + default=DEFAULT_IMAGE_HEIGHT, + type=int, + help="Base crop size (along height) during training", + ) + + group.add_argument( + "--sampler.msc.min-crop-size-width", + default=160, + type=int, + help="Min. crop size along width during training", + ) + group.add_argument( + "--sampler.msc.max-crop-size-width", + default=320, + type=int, + help="Max. crop size along width during training", + ) + + group.add_argument( + "--sampler.msc.min-crop-size-height", + default=160, + type=int, + help="Min. crop size along height during training", + ) + group.add_argument( + "--sampler.msc.max-crop-size-height", + default=320, + type=int, + help="Max. crop size along height during training", + ) + group.add_argument( + "--sampler.msc.max-n-scales", + default=5, + type=int, + help="Max. scales in variable batch sampler. For example, [0.25, 0.5, 0.75, 1, 1.25] ", + ) + group.add_argument( + "--sampler.msc.check-scale", + default=32, + type=int, + help="Image scales should be divisible by this factor", + ) + group.add_argument( + "--sampler.msc.ep-intervals", + default=[40], + type=int, + help="Epoch intervals at which scales are adjusted", + ) + group.add_argument( + "--sampler.msc.scale-inc-factor", + default=0.25, + type=float, + help="Factor by which we should increase the scale", + ) + group.add_argument( + "--sampler.msc.scale-inc", + action="store_true", + help="Increase image scales during training", + ) + + return parser + + def __iter__(self): + img_indices = self.get_indices() + start_index = 0 + n_samples = len(img_indices) + while start_index < n_samples: + crop_h, crop_w, batch_size = random.choice(self.img_batch_tuples) + + end_index = min(start_index + batch_size, n_samples) + batch_ids = img_indices[start_index:end_index] + n_batch_samples = len(batch_ids) + if len(batch_ids) != batch_size: + batch_ids += img_indices[: (batch_size - n_batch_samples)] + start_index += batch_size + + if len(batch_ids) > 0: + batch = [(crop_h, crop_w, b_id) for b_id in batch_ids] + yield batch + + def update_scales(self, epoch, is_master_node=False, *args, **kwargs): + pass + + def __repr__(self): + repr_str = "{}(".format(self.__class__.__name__) + repr_str += ( + "\n\t base_im_size=(h={}, w={}), " + "\n\t base_batch_size={} " + "\n\t scales={} " + "\n\t scale_inc={} " + "\n\t scale_inc_factor={} " + "\n\t ep_intervals={}".format( + self.crop_size_h, + self.crop_size_w, + self.batch_size_gpu0, + self.img_batch_tuples, + self.scale_inc, + self.scale_inc_factor, + self.scale_ep_intervals, + ) + ) + repr_str += self.extra_repr() + repr_str += "\n)" + return repr_str + + +@register_sampler(name="multi_scale_sampler_ddp") +class MultiScaleSamplerDDP(BaseSamplerDDP): + """ + Multi-scale Batch Sampler for distributed data parallel + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, n_data_samples=n_data_samples, is_training=is_training + ) + crop_size_w: int = getattr( + opts, "sampler.msc.crop_size_width", DEFAULT_IMAGE_WIDTH + ) + crop_size_h: int = getattr( + opts, "sampler.msc.crop_size_height", DEFAULT_IMAGE_HEIGHT + ) + + min_crop_size_w: int = getattr(opts, "sampler.msc.min_crop_size_width", 160) + max_crop_size_w: int = getattr(opts, "sampler.msc.max_crop_size_width", 320) + + min_crop_size_h: int = getattr(opts, "sampler.msc.min_crop_size_height", 160) + max_crop_size_h: int = getattr(opts, "sampler.msc.max_crop_size_height", 320) + + scale_inc: bool = getattr(opts, "sampler.msc.scale_inc", False) + scale_ep_intervals: list or int = getattr( + opts, "sampler.msc.ep_intervals", [40] + ) + scale_inc_factor: float = getattr(opts, "sampler.msc.scale_inc_factor", 0.25) + check_scale_div_factor: int = getattr(opts, "sampler.msc.check_scale", 32) + + max_img_scales: int = getattr(opts, "sampler.msc.max_n_scales", 10) + + self.crop_size_h = crop_size_h + self.crop_size_w = crop_size_w + self.min_crop_size_h = min_crop_size_h + self.max_crop_size_h = max_crop_size_h + self.min_crop_size_w = min_crop_size_w + self.max_crop_size_w = max_crop_size_w + + self.scale_inc_factor = scale_inc_factor + self.scale_ep_intervals = scale_ep_intervals + self.max_img_scales = max_img_scales + self.check_scale_div_factor = check_scale_div_factor + self.scale_inc = scale_inc + + if is_training: + self.img_batch_tuples = _image_batch_pairs( + crop_size_h=self.crop_size_h, + crop_size_w=self.crop_size_w, + batch_size_gpu0=self.batch_size_gpu0, + n_gpus=self.num_replicas, + max_scales=self.max_img_scales, + check_scale_div_factor=self.check_scale_div_factor, + min_crop_size_w=self.min_crop_size_w, + max_crop_size_w=self.max_crop_size_w, + min_crop_size_h=self.min_crop_size_h, + max_crop_size_h=self.max_crop_size_h, + ) + self.img_batch_tuples = [ + (h, w, self.batch_size_gpu0) for h, w, b in self.img_batch_tuples + ] + else: + self.img_batch_tuples = [ + (self.crop_size_h, self.crop_size_w, self.batch_size_gpu0) + ] + + def __iter__(self): + indices_rank_i = self.get_indices_rank_i() + + start_index = 0 + n_samples_rank_i = len(indices_rank_i) + while start_index < n_samples_rank_i: + crop_h, crop_w, batch_size = random.choice(self.img_batch_tuples) + + end_index = min(start_index + batch_size, n_samples_rank_i) + batch_ids = indices_rank_i[start_index:end_index] + n_batch_samples = len(batch_ids) + if n_batch_samples != batch_size: + batch_ids += indices_rank_i[: (batch_size - n_batch_samples)] + start_index += batch_size + + if len(batch_ids) > 0: + batch = [(crop_h, crop_w, b_id) for b_id in batch_ids] + yield batch + + def update_scales(self, epoch, is_master_node=False, *args, **kwargs): + pass + + def __repr__(self): + repr_str = "{}(".format(self.__class__.__name__) + repr_str += ( + "\n\t base_im_size=(h={}, w={}), " + "\n\t base_batch_size={} " + "\n\t scales={} " + "\n\t scale_inc={} " + "\n\t scale_inc_factor={} " + "\n\t ep_intervals={}".format( + self.crop_size_h, + self.crop_size_w, + self.batch_size_gpu0, + self.img_batch_tuples, + self.scale_inc, + self.scale_inc_factor, + self.scale_ep_intervals, + ) + ) + repr_str += self.extra_repr() + repr_str += "\n )" + return repr_str diff --git a/Adaptive Frequency Filters/data/sampler/utils.py b/Adaptive Frequency Filters/data/sampler/utils.py new file mode 100644 index 0000000..bdc9623 --- /dev/null +++ b/Adaptive Frequency Filters/data/sampler/utils.py @@ -0,0 +1,125 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from typing import Optional, List +import numpy as np + +from utils.math_utils import make_divisible + + +def _image_batch_pairs( + crop_size_w: int, + crop_size_h: int, + batch_size_gpu0: int, + n_gpus: int, + max_scales: Optional[float] = 5, + check_scale_div_factor: Optional[int] = 32, + min_crop_size_w: Optional[int] = 160, + max_crop_size_w: Optional[int] = 320, + min_crop_size_h: Optional[int] = 160, + max_crop_size_h: Optional[int] = 320, + *args, + **kwargs +) -> List: + """ + This function creates batch and image size pairs. For a given batch size and image size, different image sizes + are generated and batch size is adjusted so that GPU memory can be utilized efficiently. + + Args: + crop_size_w (int): Base Image width (e.g., 224) + crop_size_h (int): Base Image height (e.g., 224) + batch_size_gpu0 (int): Batch size on GPU 0 for base image + n_gpus (int): Number of available GPUs + max_scales (Optional[int]): Number of scales. How many image sizes that we want to generate between min and max scale factors. Default: 5 + check_scale_div_factor (Optional[int]): Check if image scales are divisible by this factor. Default: 32 + min_crop_size_w (Optional[int]): Min. crop size along width. Default: 160 + max_crop_size_w (Optional[int]): Max. crop size along width. Default: 320 + min_crop_size_h (Optional[int]): Min. crop size along height. Default: 160 + max_crop_size_h (Optional[int]): Max. crop size along height. Default: 320 + + Returns: + a sorted list of tuples. Each index is of the form (h, w, batch_size) + + """ + width_dims = list(np.linspace(min_crop_size_w, max_crop_size_w, max_scales)) + if crop_size_w not in width_dims: + width_dims.append(crop_size_w) + + height_dims = list(np.linspace(min_crop_size_h, max_crop_size_h, max_scales)) + if crop_size_h not in height_dims: + height_dims.append(crop_size_h) + + image_scales = set() + + for h, w in zip(height_dims, width_dims): + # ensure that sampled sizes are divisible by check_scale_div_factor + # This is important in some cases where input undergoes a fixed number of down-sampling stages + # for instance, in ImageNet training, CNNs usually have 5 downsampling stages, which downsamples the + # input image of resolution 224x224 to 7x7 size + h = make_divisible(h, check_scale_div_factor) + w = make_divisible(w, check_scale_div_factor) + image_scales.add((h, w)) + + image_scales = list(image_scales) + + img_batch_tuples = set() + n_elements = crop_size_w * crop_size_h * batch_size_gpu0 + for (crop_h, crop_y) in image_scales: + # compute the batch size for sampled image resolutions with respect to the base resolution + _bsz = max(1, int(round(n_elements / (crop_h * crop_y), 2))) + + img_batch_tuples.add((crop_h, crop_y, _bsz)) + + img_batch_tuples = list(img_batch_tuples) + return sorted(img_batch_tuples) + + +def make_video_pairs( + crop_size_h: int, + crop_size_w: int, + min_crop_size_h: int, + max_crop_size_h: int, + min_crop_size_w: int, + max_crop_size_w: int, + default_frames: int, + max_scales: Optional[int] = 5, + check_scale_div_factor: Optional[int] = 32, + *args, + **kwargs +) -> List: + """ + This function creates number of frames and spatial size pairs for videos. + + Args: + crop_size_h (int): Base Image height (e.g., 224) + crop_size_w (int): Base Image width (e.g., 224) + min_crop_size_w (int): Min. crop size along width. + max_crop_size_w (int): Max. crop size along width. + min_crop_size_h (int): Min. crop size along height. + max_crop_size_h (int): Max. crop size along height. + default_frames (int): Default number of frames per clip in a video. + max_scales (Optional[int]): Number of scales. Default: 5 + check_scale_div_factor (Optional[int]): Check if spatial scales are divisible by this factor. Default: 32 + Returns: + a sorted list of tuples. Each index is of the form (h, w, n_frames) + """ + + width_dims = list(np.linspace(min_crop_size_w, max_crop_size_w, max_scales)) + if crop_size_w not in width_dims: + width_dims.append(crop_size_w) + height_dims = list(np.linspace(min_crop_size_h, max_crop_size_h, max_scales)) + if crop_size_h not in height_dims: + height_dims.append(crop_size_h) + + # ensure that spatial dimensions are divisible by check_scale_div_factor + width_dims = [make_divisible(w, check_scale_div_factor) for w in width_dims] + height_dims = [make_divisible(h, check_scale_div_factor) for h in height_dims] + batch_pairs = set() + n_elements = crop_size_w * crop_size_h * default_frames + for (h, w) in zip(height_dims, width_dims): + n_frames = max(1, int(round(n_elements / (h * w), 2))) + batch_pairs.add((h, w, n_frames)) + return sorted(list(batch_pairs)) diff --git a/Adaptive Frequency Filters/data/sampler/variable_batch_sampler.py b/Adaptive Frequency Filters/data/sampler/variable_batch_sampler.py new file mode 100644 index 0000000..0886680 --- /dev/null +++ b/Adaptive Frequency Filters/data/sampler/variable_batch_sampler.py @@ -0,0 +1,422 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- +import copy +import random +import argparse +from typing import Optional +import numpy as np +import math + +from utils import logger +from common import DEFAULT_IMAGE_WIDTH, DEFAULT_IMAGE_HEIGHT + +from . import register_sampler, BaseSamplerDP, BaseSamplerDDP +from .utils import _image_batch_pairs + + +@register_sampler(name="variable_batch_sampler") +class VariableBatchSampler(BaseSamplerDP): + """ + `Variably-size multi-scale batch sampler ` for data parallel + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, n_data_samples=n_data_samples, is_training=is_training + ) + + crop_size_w: int = getattr( + opts, "sampler.vbs.crop_size_width", DEFAULT_IMAGE_WIDTH + ) + crop_size_h: int = getattr( + opts, "sampler.vbs.crop_size_height", DEFAULT_IMAGE_HEIGHT + ) + + min_crop_size_w: int = getattr(opts, "sampler.vbs.min_crop_size_width", 160) + max_crop_size_w: int = getattr(opts, "sampler.vbs.max_crop_size_width", 320) + + min_crop_size_h: int = getattr(opts, "sampler.vbs.min_crop_size_height", 160) + max_crop_size_h: int = getattr(opts, "sampler.vbs.max_crop_size_height", 320) + + scale_inc: bool = getattr(opts, "sampler.vbs.scale_inc", False) + scale_ep_intervals: list or int = getattr( + opts, "sampler.vbs.ep_intervals", [40] + ) + min_scale_inc_factor: float = getattr( + opts, "sampler.vbs.min_scale_inc_factor", 1.0 + ) + max_scale_inc_factor: float = getattr( + opts, "sampler.vbs.max_scale_inc_factor", 1.0 + ) + + check_scale_div_factor: int = getattr(opts, "sampler.vbs.check_scale", 32) + max_img_scales: int = getattr(opts, "sampler.vbs.max_n_scales", 10) + + if isinstance(scale_ep_intervals, int): + scale_ep_intervals = [scale_ep_intervals] + + self.min_crop_size_w = min_crop_size_w + self.max_crop_size_w = max_crop_size_w + self.min_crop_size_h = min_crop_size_h + self.max_crop_size_h = max_crop_size_h + + self.crop_size_w = crop_size_w + self.crop_size_h = crop_size_h + + self.min_scale_inc_factor = min_scale_inc_factor + self.max_scale_inc_factor = max_scale_inc_factor + self.scale_ep_intervals = scale_ep_intervals + + self.max_img_scales = max_img_scales + self.check_scale_div_factor = check_scale_div_factor + self.scale_inc = scale_inc + + if is_training: + self.img_batch_tuples = _image_batch_pairs( + crop_size_h=self.crop_size_h, + crop_size_w=self.crop_size_w, + batch_size_gpu0=self.batch_size_gpu0, + n_gpus=self.n_gpus, + max_scales=self.max_img_scales, + check_scale_div_factor=self.check_scale_div_factor, + min_crop_size_w=self.min_crop_size_w, + max_crop_size_w=self.max_crop_size_w, + min_crop_size_h=self.min_crop_size_h, + max_crop_size_h=self.max_crop_size_h, + ) + else: + self.img_batch_tuples = [(crop_size_h, crop_size_w, self.batch_size_gpu0)] + + def __iter__(self): + img_indices = self.get_indices() + start_index = 0 + n_samples = len(img_indices) + while start_index < n_samples: + crop_h, crop_w, batch_size = random.choice(self.img_batch_tuples) + + end_index = min(start_index + batch_size, n_samples) + batch_ids = img_indices[start_index:end_index] + n_batch_samples = len(batch_ids) + if len(batch_ids) != batch_size: + batch_ids += img_indices[: (batch_size - n_batch_samples)] + start_index += batch_size + + if len(batch_ids) > 0: + batch = [(crop_h, crop_w, b_id) for b_id in batch_ids] + yield batch + + def update_scales(self, epoch, is_master_node=False, *args, **kwargs): + if epoch in self.scale_ep_intervals and self.scale_inc: + self.min_crop_size_w += int( + self.min_crop_size_w * self.min_scale_inc_factor + ) + self.max_crop_size_w += int( + self.max_crop_size_w * self.max_scale_inc_factor + ) + + self.min_crop_size_h += int( + self.min_crop_size_h * self.min_scale_inc_factor + ) + self.max_crop_size_h += int( + self.max_crop_size_h * self.max_scale_inc_factor + ) + + self.img_batch_tuples = _image_batch_pairs( + crop_size_h=self.crop_size_h, + crop_size_w=self.crop_size_w, + batch_size_gpu0=self.batch_size_gpu0, + n_gpus=self.n_gpus, + max_scales=self.max_img_scales, + check_scale_div_factor=self.check_scale_div_factor, + min_crop_size_w=self.min_crop_size_w, + max_crop_size_w=self.max_crop_size_w, + min_crop_size_h=self.min_crop_size_h, + max_crop_size_h=self.max_crop_size_h, + ) + if is_master_node: + logger.log("Scales updated in {}".format(self.__class__.__name__)) + logger.log("New scales: {}".format(self.img_batch_tuples)) + + def __repr__(self): + repr_str = "{}(".format(self.__class__.__name__) + repr_str += ( + "\n\t base_im_size=(h={}, w={}), " + "\n\t base_batch_size={} " + "\n\t scales={} " + "\n\t scale_inc={} " + "\n\t min_scale_inc_factor={} " + "\n\t max_scale_inc_factor={} " + "\n\t ep_intervals={}".format( + self.crop_size_h, + self.crop_size_w, + self.batch_size_gpu0, + self.img_batch_tuples, + self.scale_inc, + self.min_scale_inc_factor, + self.max_scale_inc_factor, + self.scale_ep_intervals, + ) + ) + repr_str += self.extra_repr() + repr_str += "\n)" + return repr_str + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Variable batch sampler", + description="Arguments related to variable batch sampler", + ) + group.add_argument( + "--sampler.vbs.crop-size-width", + default=DEFAULT_IMAGE_WIDTH, + type=int, + help="Base crop size (along width) during training", + ) + group.add_argument( + "--sampler.vbs.crop-size-height", + default=DEFAULT_IMAGE_HEIGHT, + type=int, + help="Base crop size (along height) during training", + ) + + group.add_argument( + "--sampler.vbs.min-crop-size-width", + default=160, + type=int, + help="Min. crop size along width during training", + ) + group.add_argument( + "--sampler.vbs.max-crop-size-width", + default=320, + type=int, + help="Max. crop size along width during training", + ) + + group.add_argument( + "--sampler.vbs.min-crop-size-height", + default=160, + type=int, + help="Min. crop size along height during training", + ) + group.add_argument( + "--sampler.vbs.max-crop-size-height", + default=320, + type=int, + help="Max. crop size along height during training", + ) + group.add_argument( + "--sampler.vbs.max-n-scales", + default=5, + type=int, + help="Max. scales in variable batch sampler. For example, [0.25, 0.5, 0.75, 1, 1.25] ", + ) + group.add_argument( + "--sampler.vbs.check-scale", + default=32, + type=int, + help="Image scales should be divisible by this factor", + ) + group.add_argument( + "--sampler.vbs.ep-intervals", + default=[40], + type=int, + help="Epoch intervals at which scales are adjusted", + ) + group.add_argument( + "--sampler.vbs.min-scale-inc-factor", + default=1.0, + type=float, + help="Factor by which we should increase the minimum scale", + ) + group.add_argument( + "--sampler.vbs.max-scale-inc-factor", + default=1.0, + type=float, + help="Factor by which we should increase the maximum scale", + ) + group.add_argument( + "--sampler.vbs.scale-inc", + action="store_true", + help="Increase image scales during training", + ) + + return parser + + +@register_sampler(name="variable_batch_sampler_ddp") +class VariableBatchSamplerDDP(BaseSamplerDDP): + """ + `Variably-size multi-scale batch sampler ` for distributed + data parallel + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + """ + + :param opts: arguments + :param n_data_samples: number of data samples in the dataset + :param is_training: Training or evaluation mode (eval mode includes validation mode) + """ + super().__init__( + opts=opts, n_data_samples=n_data_samples, is_training=is_training + ) + crop_size_w: int = getattr( + opts, "sampler.vbs.crop_size_width", DEFAULT_IMAGE_WIDTH + ) + crop_size_h: int = getattr( + opts, "sampler.vbs.crop_size_height", DEFAULT_IMAGE_HEIGHT + ) + + min_crop_size_w: int = getattr(opts, "sampler.vbs.min_crop_size_width", 160) + max_crop_size_w: int = getattr(opts, "sampler.vbs.max_crop_size_width", 320) + + min_crop_size_h: int = getattr(opts, "sampler.vbs.min_crop_size_height", 160) + max_crop_size_h: int = getattr(opts, "sampler.vbs.max_crop_size_height", 320) + + scale_inc: bool = getattr(opts, "sampler.vbs.scale_inc", False) + scale_ep_intervals: list or int = getattr( + opts, "sampler.vbs.ep_intervals", [40] + ) + min_scale_inc_factor: float = getattr( + opts, "sampler.vbs.min_scale_inc_factor", 1.0 + ) + max_scale_inc_factor: float = getattr( + opts, "sampler.vbs.max_scale_inc_factor", 1.0 + ) + check_scale_div_factor: int = getattr(opts, "sampler.vbs.check_scale", 32) + + max_img_scales: int = getattr(opts, "sampler.vbs.max_n_scales", 10) + + self.crop_size_h = crop_size_h + self.crop_size_w = crop_size_w + self.min_crop_size_h = min_crop_size_h + self.max_crop_size_h = max_crop_size_h + self.min_crop_size_w = min_crop_size_w + self.max_crop_size_w = max_crop_size_w + + self.min_scale_inc_factor = min_scale_inc_factor + self.max_scale_inc_factor = max_scale_inc_factor + self.scale_ep_intervals = scale_ep_intervals + self.max_img_scales = max_img_scales + self.check_scale_div_factor = check_scale_div_factor + self.scale_inc = scale_inc + + if is_training: + self.img_batch_tuples = _image_batch_pairs( + crop_size_h=self.crop_size_h, + crop_size_w=self.crop_size_w, + batch_size_gpu0=self.batch_size_gpu0, + n_gpus=self.num_replicas, + max_scales=self.max_img_scales, + check_scale_div_factor=self.check_scale_div_factor, + min_crop_size_w=self.min_crop_size_w, + max_crop_size_w=self.max_crop_size_w, + min_crop_size_h=self.min_crop_size_h, + max_crop_size_h=self.max_crop_size_h, + ) + else: + self.img_batch_tuples = [ + (self.crop_size_h, self.crop_size_w, self.batch_size_gpu0) + ] + + def __iter__(self): + indices_rank_i = self.get_indices_rank_i() + start_index = 0 + n_samples_rank_i = len(indices_rank_i) + while start_index < n_samples_rank_i: + crop_h, crop_w, batch_size = random.choice(self.img_batch_tuples) + + end_index = min(start_index + batch_size, n_samples_rank_i) + batch_ids = indices_rank_i[start_index:end_index] + n_batch_samples = len(batch_ids) + if n_batch_samples != batch_size: + batch_ids += indices_rank_i[: (batch_size - n_batch_samples)] + start_index += batch_size + + if len(batch_ids) > 0: + batch = [(crop_h, crop_w, b_id) for b_id in batch_ids] + yield batch + + def update_scales(self, epoch, is_master_node=False, *args, **kwargs): + if (epoch in self.scale_ep_intervals) and self.scale_inc: # Training mode + self.min_crop_size_w += int( + self.min_crop_size_w * self.min_scale_inc_factor + ) + self.max_crop_size_w += int( + self.max_crop_size_w * self.max_scale_inc_factor + ) + + self.min_crop_size_h += int( + self.min_crop_size_h * self.min_scale_inc_factor + ) + self.max_crop_size_h += int( + self.max_crop_size_h * self.max_scale_inc_factor + ) + + self.img_batch_tuples = _image_batch_pairs( + crop_size_h=self.crop_size_h, + crop_size_w=self.crop_size_w, + batch_size_gpu0=self.batch_size_gpu0, + n_gpus=self.num_replicas, + max_scales=self.max_img_scales, + check_scale_div_factor=self.check_scale_div_factor, + min_crop_size_w=self.min_crop_size_w, + max_crop_size_w=self.max_crop_size_w, + min_crop_size_h=self.min_crop_size_h, + max_crop_size_h=self.max_crop_size_h, + ) + if is_master_node: + logger.log("Scales updated in {}".format(self.__class__.__name__)) + logger.log("New scales: {}".format(self.img_batch_tuples)) + + def __repr__(self): + repr_str = "{}(".format(self.__class__.__name__) + repr_str += ( + "\n\t base_im_size=(h={}, w={}), " + "\n\t base_batch_size={} " + "\n\t scales={} " + "\n\t scale_inc={} " + "\n\t min_scale_inc_factor={} " + "\n\t max_scale_inc_factor={} " + "\n\t ep_intervals={} ".format( + self.crop_size_h, + self.crop_size_w, + self.batch_size_gpu0, + self.img_batch_tuples, + self.scale_inc, + self.min_scale_inc_factor, + self.max_scale_inc_factor, + self.scale_ep_intervals, + ) + ) + repr_str += self.extra_repr() + repr_str += "\n )" + return repr_str diff --git a/Adaptive Frequency Filters/data/sampler/video_batch_sampler.py b/Adaptive Frequency Filters/data/sampler/video_batch_sampler.py new file mode 100644 index 0000000..c54b736 --- /dev/null +++ b/Adaptive Frequency Filters/data/sampler/video_batch_sampler.py @@ -0,0 +1,163 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import random +import argparse +from typing import Optional + +from . import register_sampler +from .batch_sampler import BatchSamplerDDP, BatchSampler + + +@register_sampler(name="video_batch_sampler") +class VideoBatchSampler(BatchSampler): + """ + Batch sampler for videos + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, n_data_samples=n_data_samples, is_training=is_training + ) + self.default_frames = getattr(opts, "sampler.bs.num_frames_per_clip", 8) + + self.clips_per_video = getattr(opts, "sampler.bs.clips_per_video", 1) + + def __iter__(self): + indices = self.get_indices() + + start_index = 0 + batch_size = self.batch_size_gpu0 + indices_len = len(indices) + while start_index < indices_len: + + end_index = min(start_index + batch_size, indices_len) + batch_ids = indices[start_index:end_index] + start_index += batch_size + + if len(batch_ids) > 0: + batch = [ + ( + self.crop_size_h, + self.crop_size_w, + b_id, + self.default_frames, + self.clips_per_video, + ) + for b_id in batch_ids + ] + yield batch + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Batch sampler for videos", + description="Arguments related to variable batch sampler", + ) + group.add_argument( + "--sampler.bs.num-frames-per-clip", + default=8, + type=int, + help="Number of frames per video clip", + ) + group.add_argument( + "--sampler.bs.clips-per-video", + default=1, + type=int, + help="Number of clips per video", + ) + return parser + + def __repr__(self): + repr_str = "{}(".format(self.__class__.__name__) + repr_str += "\n \t base_im_size=(h={}, w={})\n \t base_batch_size={}\n \t n_clips={}\n \tn_frames={}".format( + self.crop_size_h, + self.crop_size_w, + self.batch_size_gpu0, + self.clips_per_video, + self.default_frames, + ) + repr_str += self.extra_repr() + repr_str += "\n)" + return repr_str + + +@register_sampler(name="video_batch_sampler_ddp") +class VideoBatchSamplerDDP(BatchSamplerDDP): + """ + Batch sampler for videos (DDP) + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, n_data_samples=n_data_samples, is_training=is_training + ) + self.default_frames = getattr(opts, "sampler.bs.num_frames_per_clip", 8) + self.clips_per_video = getattr(opts, "sampler.bs.clips_per_video", 1) + + def __iter__(self): + indices_rank_i = self.get_indices_rank_i() + + start_index = 0 + batch_size = self.batch_size_gpu0 + indices_len = len(indices_rank_i) + while start_index < indices_len: + end_index = min(start_index + batch_size, indices_len) + batch_ids = indices_rank_i[start_index:end_index] + n_batch_samples = len(batch_ids) + if n_batch_samples != batch_size: + batch_ids += indices_rank_i[: (batch_size - n_batch_samples)] + start_index += batch_size + + if len(batch_ids) > 0: + batch = [ + ( + self.crop_size_h, + self.crop_size_w, + b_id, + self.default_frames, + self.clips_per_video, + ) + for b_id in batch_ids + ] + yield batch + + def __repr__(self): + repr_str = "{}(".format(self.__class__.__name__) + repr_str += "\n \t base_im_size=(h={}, w={})\n \t base_batch_size={}\n \t n_clips={}\n \tn_frames={}".format( + self.crop_size_h, + self.crop_size_w, + self.batch_size_gpu0, + self.clips_per_video, + self.default_frames, + ) + repr_str += self.extra_repr() + repr_str += "\n)" + return repr_str diff --git a/Adaptive Frequency Filters/data/sampler/video_variable_seq_sampler.py b/Adaptive Frequency Filters/data/sampler/video_variable_seq_sampler.py new file mode 100644 index 0000000..e603d9a --- /dev/null +++ b/Adaptive Frequency Filters/data/sampler/video_variable_seq_sampler.py @@ -0,0 +1,318 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import random +import argparse +from typing import Optional + +from utils import logger + +from .utils import make_video_pairs +from . import register_sampler +from .variable_batch_sampler import VariableBatchSampler, VariableBatchSamplerDDP + + +@register_sampler(name="video_variable_seq_sampler") +class VideoVariableSeqSampler(VariableBatchSampler): + """ + Extends `Variably-size multi-scale batch sampler ` for videos + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, n_data_samples=n_data_samples, is_training=is_training + ) + self.default_frames = getattr(opts, "sampler.vbs.num_frames_per_clip", 8) + + self.random_video_clips = ( + getattr(opts, "sampler.vbs.random_video_clips", False) + if is_training + else False + ) + self.min_clips_per_video = getattr(opts, "sampler.vbs.min_clips_per_video", 1) + self.max_clips_per_video = getattr(opts, "sampler.vbs.max_clips_per_video", 5) + self.clips_per_video = getattr(opts, "sampler.vbs.clips_per_video", 1) + if self.min_clips_per_video is None: + self.min_clips_per_video = 1 + + if is_training: + # override img_batch_tuples + self.img_batch_tuples = make_video_pairs( + crop_size_h=self.crop_size_h, + crop_size_w=self.crop_size_w, + min_crop_size_h=self.min_crop_size_h, + max_crop_size_h=self.max_crop_size_h, + min_crop_size_w=self.min_crop_size_w, + max_crop_size_w=self.max_crop_size_w, + max_scales=self.max_img_scales, + check_scale_div_factor=self.check_scale_div_factor, + default_frames=self.default_frames, + ) + else: + self.img_batch_tuples = [ + (self.crop_size_h, self.crop_size_w, self.default_frames) + ] + + def __repr__(self): + repr_str = "{}(".format(self.__class__.__name__) + + if self.shuffle: + n_clips_str = "(min={}, max={})".format( + self.min_clips_per_video, self.max_clips_per_video + ) + else: + n_clips_str = self.clips_per_video + + repr_str += ( + "\n\t base_im_size=(h={}, w={}), " + "\n\t base_batch_size={} " + "\n\t scales (Height x Width x N_frames)={} " + "\n\t scale_inc={} " + "\n\t min_scale_inc_factor={} " + "\n\t max_scale_inc_factor={} " + "\n\t ep_intervals={}" + "\n\t num_repeat={}" + "\n\t num_clips={}".format( + self.crop_size_h, + self.crop_size_w, + self.batch_size_gpu0, + self.img_batch_tuples, + self.scale_inc, + self.min_scale_inc_factor, + self.max_scale_inc_factor, + self.scale_ep_intervals, + self.num_repeats, + n_clips_str, + ) + ) + repr_str += self.extra_repr() + repr_str += "\n)" + return repr_str + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Variable sequence sampler for videos", + description="Arguments related to variable sequence sampler", + ) + group.add_argument( + "--sampler.vbs.num-frames-per-clip", + default=8, + type=int, + help="Default frames per video", + ) + + group.add_argument( + "--sampler.vbs.random-video-clips", + action="store_true", + help="Sample number of clips per video randomly during training between min and max values specified using " + "--dataset.kinetics.min-clips-per-video and --dataset.kinetics.max-clips-per-video arguments " + "respectively", + ) + group.add_argument( + "--sampler.vbs.min-clips-per-video", + type=int, + default=1, + help="Minimum number of clips per video. Used only for training", + ) + group.add_argument( + "--sampler.vbs.max-clips-per-video", + type=int, + default=5, + help="Maximum number of clips per video. Used only for training", + ) + group.add_argument( + "--sampler.vbs.clips-per-video", + type=int, + default=1, + help="Number of clips per video", + ) + group.add_argument( + "--sampler.vbs.min-frames-per-clip", + type=int, + default=None, + help="Minimum number of frames per clip", + ) + + return parser + + def __iter__(self): + indices = self.get_indices() + + start_index = 0 + indices_len = len(indices) + while start_index < indices_len: + if self.random_video_clips: + # randomly sample number of clips and adjust frames per clip + n_clips = max( + 1, + random.randint(self.min_clips_per_video, self.max_clips_per_video), + ) + batch_size = max( + self.batch_size_gpu0, + self.batch_size_gpu0 * (self.clips_per_video // n_clips), + ) + else: + n_clips = self.clips_per_video + batch_size = self.batch_size_gpu0 + + crop_h, crop_w, n_frames = random.choice(self.img_batch_tuples) + end_index = min(start_index + batch_size, indices_len) + batch_ids = indices[start_index:end_index] + n_batch_samples = len(batch_ids) + if len(batch_ids) != batch_size: + batch_ids += indices[: (batch_size - n_batch_samples)] + start_index += batch_size + + if len(batch_ids) > 0: + + batch = [ + (crop_h, crop_w, b_id, n_frames, n_clips) for b_id in batch_ids + ] + yield batch + + def update_scales(self, epoch, is_master_node=False, *args, **kwargs): + pass + + +@register_sampler(name="video_variable_seq_sampler_ddp") +class VideoVariableSeqSamplerDDP(VariableBatchSamplerDDP): + """ + Extends `Variably-size multi-scale batch sampler ` for videos + + Args: + opts: command line argument + n_data_samples (int): Number of samples in the dataset + is_training (Optional[bool]): Training or validation mode. Default: False + """ + + def __init__( + self, + opts, + n_data_samples: int, + is_training: Optional[bool] = False, + *args, + **kwargs + ) -> None: + super().__init__( + opts=opts, n_data_samples=n_data_samples, is_training=is_training + ) + self.default_frames = getattr(opts, "sampler.vbs.num_frames_per_clip", 8) + + self.random_video_clips = ( + getattr(opts, "sampler.vbs.random_video_clips", False) + if is_training + else False + ) + self.min_clips_per_video = getattr(opts, "sampler.vbs.min_clips_per_video", 1) + self.max_clips_per_video = getattr(opts, "sampler.vbs.max_clips_per_video", 5) + self.clips_per_video = getattr(opts, "sampler.vbs.clips_per_video", 1) + if self.min_clips_per_video is None: + self.min_clips_per_video = 1 + + if is_training: + # override img_batch_tuples + self.img_batch_tuples = make_video_pairs( + crop_size_h=self.crop_size_h, + crop_size_w=self.crop_size_w, + min_crop_size_h=self.min_crop_size_h, + max_crop_size_h=self.max_crop_size_h, + min_crop_size_w=self.min_crop_size_w, + max_crop_size_w=self.max_crop_size_w, + max_scales=self.max_img_scales, + check_scale_div_factor=self.check_scale_div_factor, + default_frames=self.default_frames, + ) + else: + self.img_batch_tuples = [ + (self.crop_size_h, self.crop_size_w, self.default_frames) + ] + + def __repr__(self): + repr_str = "{}(".format(self.__class__.__name__) + + if self.shuffle: + n_clips_str = "(min={}, max={})".format( + self.min_clips_per_video, self.max_clips_per_video + ) + else: + n_clips_str = self.clips_per_video + + repr_str += ( + "\n\t base_im_size=(h={}, w={}), " + "\n\t base_batch_size={} " + "\n\t scales (Height x Width x N_frames)={} " + "\n\t scale_inc={} " + "\n\t min_scale_inc_factor={} " + "\n\t max_scale_inc_factor={} " + "\n\t ep_intervals={}" + "\n\t num_repeat={}" + "\n\t num_clips={}".format( + self.crop_size_h, + self.crop_size_w, + self.batch_size_gpu0, + self.img_batch_tuples, + self.scale_inc, + self.min_scale_inc_factor, + self.max_scale_inc_factor, + self.scale_ep_intervals, + self.num_repeats, + n_clips_str, + ) + ) + repr_str += self.extra_repr() + repr_str += "\n)" + return repr_str + + def __iter__(self): + indices_rank_i = self.get_indices_rank_i() + + start_index = 0 + n_samples_rank_i = len(indices_rank_i) + while start_index < n_samples_rank_i: + if self.random_video_clips: + # randomly sample number of clips and adjust batch size + n_clips = max( + 1, + random.randint(self.min_clips_per_video, self.max_clips_per_video), + ) + batch_size = max( + self.batch_size_gpu0, + self.batch_size_gpu0 * (self.clips_per_video // n_clips), + ) + else: + n_clips = self.clips_per_video + batch_size = self.batch_size_gpu0 + + crop_h, crop_w, n_frames = random.choice(self.img_batch_tuples) + + end_index = min(start_index + batch_size, n_samples_rank_i) + batch_ids = indices_rank_i[start_index:end_index] + n_batch_samples = len(batch_ids) + if n_batch_samples != batch_size: + batch_ids += indices_rank_i[: (batch_size - n_batch_samples)] + start_index += batch_size + + if len(batch_ids) > 0: + batch = [ + (crop_h, crop_w, b_id, n_frames, n_clips) for b_id in batch_ids + ] + yield batch + + def update_scales(self, epoch, is_master_node=False, *args, **kwargs): + pass diff --git a/Adaptive Frequency Filters/data/text_tokenizer/__init__.py b/Adaptive Frequency Filters/data/text_tokenizer/__init__.py new file mode 100644 index 0000000..671ed4f --- /dev/null +++ b/Adaptive Frequency Filters/data/text_tokenizer/__init__.py @@ -0,0 +1,92 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import importlib +import argparse +from typing import Optional + +from utils import logger + +from .base_tokenizer import BaseTokenizer + + +TOKENIZER_REGISTRY = {} + + +def register_tokenizer(name): + # register the text_tokenizer class + def register_tokenizer_class(cls): + if name in TOKENIZER_REGISTRY: + raise ValueError( + "Cannot register duplicate text_tokenizer class ({})".format(name) + ) + + if not issubclass(cls, BaseTokenizer): + raise ValueError( + "Tokenizer ({}: {}) must extend BaseTokenizer".format( + name, cls.__name__ + ) + ) + + TOKENIZER_REGISTRY[name] = cls + return cls + + return register_tokenizer_class + + +def arguments_tokenizer(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + # add arguments for text_tokenizer + parser = BaseTokenizer.add_arguments(parser) + + # add augmentation specific arguments + for k, v in TOKENIZER_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + return parser + + +def supported_tokenizer_str(tokenizer_name: Optional[str] = None) -> None: + """Helper utility to print supported text_tokenizer names in case specified text_tokenizer + name is not part of the implemented tokenizers. + """ + supp_list = list(TOKENIZER_REGISTRY.keys()) + if tokenizer_name is None: + supp_str = "Tokenizer name can't be None. \n Supported tokenizers are:" + else: + supp_str = ( + "Tokenizer ({}) is not yet supported. \n Supported tokenizers are:".format( + tokenizer_name + ) + ) + for t_name in supp_list: + supp_str += "\n\t{}".format(t_name) + logger.error(supp_str + "\n") + + +def build_tokenizer(opts, *args, **kwargs) -> BaseTokenizer: + """Helper function to build the text_tokenizer""" + tokenizer_name = getattr(opts, "text_tokenizer.name", None) + if tokenizer_name is None: + supported_tokenizer_str(tokenizer_name) + + if tokenizer_name in list(TOKENIZER_REGISTRY.keys()): + return TOKENIZER_REGISTRY[tokenizer_name](opts, *args, **kwargs) + else: + supported_tokenizer_str(tokenizer_name) + + +# automatically import the tokenizers +tokenizer_dir = os.path.dirname(__file__) + +for file in os.listdir(tokenizer_dir): + path = os.path.join(tokenizer_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + tokenizer_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("data.text_tokenizer." + tokenizer_name) diff --git a/Adaptive Frequency Filters/data/text_tokenizer/base_tokenizer.py b/Adaptive Frequency Filters/data/text_tokenizer/base_tokenizer.py new file mode 100644 index 0000000..dd35586 --- /dev/null +++ b/Adaptive Frequency Filters/data/text_tokenizer/base_tokenizer.py @@ -0,0 +1,45 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import nn +from typing import Any +import argparse + + +class BaseTokenizer(nn.Module): + def __init__(self, opts, *args, **kwargs): + super().__init__() + self.opts = opts + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--text-tokenizer.name", + type=str, + default=None, + help="Name of the text tokenizer.", + ) + + return parser + + def get_vocab_size(self): + raise NotImplementedError + + def get_eot_token(self): + raise NotImplementedError + + def get_sot_token(self): + raise NotImplementedError + + def get_encodings(self): + raise NotImplementedError + + def forward(self, input_sentence: Any, *args, **kwargs) -> Any: + raise NotImplementedError diff --git a/Adaptive Frequency Filters/data/text_tokenizer/clip_tokenizer.py b/Adaptive Frequency Filters/data/text_tokenizer/clip_tokenizer.py new file mode 100644 index 0000000..f9d7d8b --- /dev/null +++ b/Adaptive Frequency Filters/data/text_tokenizer/clip_tokenizer.py @@ -0,0 +1,88 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse +import torch +from torch import Tensor +from typing import List +from torchtext.transforms import CLIPTokenizer + +from utils import logger +from utils.download_utils import get_local_path + +from . import BaseTokenizer, register_tokenizer + + +@register_tokenizer(name="clip") +class ClipTokenizer(BaseTokenizer): + def __init__(self, opts, *args, **kwargs): + merges_path = getattr(opts, "text_tokenizer.clip.merges_path", None) + if merges_path is None: + logger.error( + "Please specify BPE merge file using --text-tokenizer.clip.merges-path argument" + ) + + # DDP case is handled internally + merges_path = get_local_path(opts, path=merges_path) + + encoder_json_path = getattr(opts, "text_tokenizer.clip.encoder_json_path", None) + if encoder_json_path is None: + logger.error( + "Please specify Encoder JSON file using --text-tokenizer.clip.encoder-json-path argument" + ) + + encoder_json_path = get_local_path(opts, path=encoder_json_path) + + super().__init__(opts, *args, **kwargs) + self.tokenizer = CLIPTokenizer( + merges_path=merges_path, encoder_json_path=encoder_json_path + ) + # BPE encodings is a dict, where keys are tokens and values are token_ids + self.bpe_encodings = self.tokenizer.bpe.bpe_encoder_ + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--text-tokenizer.clip.merges-path", + type=str, + default=None, + help="Path to bpe merges file.", + ) + + group.add_argument( + "--text-tokenizer.clip.encoder-json-path", + type=str, + default=None, + help="Optional, path to BPE encoder json file. When specified, this is used to infer num_merges.", + ) + return parser + + def get_vocab_size(self): + return len(self.bpe_encodings) + + def get_encodings(self): + return self.bpe_encodings + + def get_eot_token(self): + return int(self.tokenizer("<|endoftext|>")[0]) + + def get_sot_token(self): + return int(self.tokenizer("<|startoftext|>")[0]) + + def forward(self, input_sentence: str, *args, **kwargs) -> Tensor: + # add start and eos tokens to input sentence + input_sentence = "<|startoftext|> " + input_sentence + " <|endoftext|>" + # tokenizer returns indices as a string + tokenized_sentence = self.tokenizer(input_sentence) + # convert string to int and then create a tensor + tokenized_sentence = torch.tensor( + [int(cap) for cap in tokenized_sentence], dtype=torch.long + ) + return tokenized_sentence diff --git a/Adaptive Frequency Filters/data/transforms/__init__.py b/Adaptive Frequency Filters/data/transforms/__init__.py new file mode 100644 index 0000000..55db80b --- /dev/null +++ b/Adaptive Frequency Filters/data/transforms/__init__.py @@ -0,0 +1,57 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import importlib +import argparse + +from .base_transforms import BaseTransformation + +SUPPORTED_AUG_CATEGORIES = [] +AUGMENTAION_REGISTRY = {} + + +def register_transformations(name, type): + def register_transformation_class(cls): + if name in AUGMENTAION_REGISTRY: + raise ValueError( + "Cannot register duplicate transformation class ({})".format(name) + ) + + if not issubclass(cls, BaseTransformation): + raise ValueError( + "Transformation ({}: {}) must extend BaseTransformation".format( + name, cls.__name__ + ) + ) + + AUGMENTAION_REGISTRY[name + "_" + type] = cls + return cls + + return register_transformation_class + + +def arguments_augmentation(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + + # add augmentation specific arguments + for k, v in AUGMENTAION_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + + return parser + + +# automatically import the augmentations +transform_dir = os.path.dirname(__file__) + +for file in os.listdir(transform_dir): + path = os.path.join(transform_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + transform_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("data.transforms." + transform_name) diff --git a/Adaptive Frequency Filters/data/transforms/base_transforms.py b/Adaptive Frequency Filters/data/transforms/base_transforms.py new file mode 100644 index 0000000..4f862b2 --- /dev/null +++ b/Adaptive Frequency Filters/data/transforms/base_transforms.py @@ -0,0 +1,27 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse +from typing import Dict + + +class BaseTransformation(object): + """ + Base class for augmentation methods + """ + + def __init__(self, opts, *args, **kwargs) -> None: + self.opts = opts + + def __call__(self, data: Dict) -> Dict: + raise NotImplementedError + + def __repr__(self) -> str: + return "{}()".format(self.__class__.__name__) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + return parser diff --git a/Adaptive Frequency Filters/data/transforms/image_opencv.py b/Adaptive Frequency Filters/data/transforms/image_opencv.py new file mode 100644 index 0000000..3bc1bf0 --- /dev/null +++ b/Adaptive Frequency Filters/data/transforms/image_opencv.py @@ -0,0 +1,1761 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import cv2 +from typing import Optional +import numpy as np +import random +import torch +import math +import argparse +from typing import Sequence, Dict, Any, Union, Tuple + +from utils import logger + +from .utils import jaccard_numpy +from . import register_transformations, BaseTransformation + +# This file is for compatibility with affnet_v0.1. In future, we won't maintain it. Please use functions from +# image_pil.py + + +_str_to_cv2_interpolation = { + "nearest": cv2.INTER_NEAREST, + "bilinear": cv2.INTER_LINEAR, + "cubic": cv2.INTER_CUBIC, +} + +_cv2_to_str_interpolation = { + cv2.INTER_NEAREST: "nearest", + cv2.INTER_LINEAR: "bilinear", + cv2.INTER_CUBIC: "cubic", +} + +_str_to_cv2_pad = { + "constant": cv2.BORDER_CONSTANT, + "edge": cv2.BORDER_REPLICATE, + "reflect": cv2.BORDER_REFLECT_101, + "symmetric": cv2.BORDER_REFLECT, +} + + +def _cv2_interpolation(interpolation): + if interpolation not in _str_to_cv2_interpolation: + interpolate_modes = list(_str_to_cv2_interpolation.keys()) + inter_str = "Supported interpolation modes are:" + for i, j in enumerate(interpolate_modes): + inter_str += "\n\t{}: {}".format(i, j) + logger.error(inter_str) + return _str_to_cv2_interpolation[interpolation] + + +def _cv2_padding(pad_mode): + if pad_mode not in _str_to_cv2_pad: + pad_modes = list(_str_to_cv2_pad.keys()) + pad_mode_str = "Supported padding modes are:" + for i, j in enumerate(pad_modes): + pad_mode_str += "\n\t{}: {}".format(i, j) + logger.error(pad_mode_str) + return _str_to_cv2_pad[pad_mode] + + +def _crop_fn(data: Dict, i: int, j: int, h: int, w: int): + img = data["image"] + crop_image = img[i : i + h, j : j + w] + data["image"] = crop_image + + if "mask" in data: + mask = data.pop("mask") + crop_mask = mask[i : i + h, j : j + w] + data["mask"] = crop_mask + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + + area_before_cropping = (boxes[..., 2] - boxes[..., 0]) * ( + boxes[..., 3] - boxes[..., 1] + ) + + boxes[..., 0::2] = np.clip(boxes[..., 0::2] - j, a_min=0, a_max=j + w) + boxes[..., 1::2] = np.clip(boxes[..., 1::2] - i, a_min=0, a_max=i + h) + + area_after_cropping = (boxes[..., 2] - boxes[..., 0]) * ( + boxes[..., 3] - boxes[..., 1] + ) + area_ratio = area_after_cropping / (area_before_cropping + 1) + + # keep the boxes whose area is atleast 20% of the area before cropping + keep = area_ratio >= 0.2 + + box_labels = data.pop("box_labels") + + data["box_coordinates"] = boxes[keep] + data["box_labels"] = box_labels[keep] + + if "instance_mask" in data: + assert "instance_coords" in data + + instance_masks = data.pop("instance_mask") + data["instance_mask"] = instance_masks[i : i + h, j : j + w] + + instance_coords = data.pop("instance_coords") + instance_coords[..., 0::2] = np.clip( + instance_coords[..., 0::2] - j, a_min=0, a_max=j + w + ) + instance_coords[..., 1::2] = np.clip( + instance_coords[..., 1::2] - i, a_min=0, a_max=i + h + ) + data["instance_coords"] = instance_coords + + return data + + +def _resize_fn( + data: Dict, size: Union[Sequence, int], interpolation: Optional[str] = "bilinear" +): + img = data["image"] + h, w = img.shape[:2] + + if isinstance(size, Sequence) and len(size) == 2: + size_h, size_w = size[0], size[1] + elif isinstance(size, int): + if (w <= h and w == size) or (h <= w and h == size): + return data + + if w < h: + size_h = int(size * h / w) + + size_w = size + else: + size_w = int(size * w / h) + size_h = size + else: + raise TypeError( + "Supported size args are int or tuple of length 2. Got inappropriate size arg: {}".format( + size + ) + ) + if isinstance(interpolation, str): + interpolation = _str_to_cv2_interpolation[interpolation] + img = cv2.resize(img, dsize=(size_w, size_h), interpolation=interpolation) + data["image"] = img + + if "mask" in data: + mask = data.pop("mask") + resized_mask = cv2.resize( + mask, dsize=(size_w, size_h), interpolation=cv2.INTER_NEAREST + ) + # this occurs when input is (H, W, 1) + if len(resized_mask.shape) != len(mask.shape): + resized_mask = resized_mask[..., None] + + data["mask"] = resized_mask + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + boxes[:, 0::2] *= 1.0 * size_w / w + boxes[:, 1::2] *= 1.0 * size_h / h + data["box_coordinates"] = boxes + + if "instance_mask" in data: + assert "instance_coords" in data + + instance_masks = data.pop("instance_mask") + + resized_instance_masks = cv2.resize( + instance_masks, dsize=(size_w, size_h), interpolation=cv2.INTER_NEAREST + ) + if len(instance_masks.shape) != len(resized_instance_masks.shape): + resized_instance_masks = resized_instance_masks[..., None] + data["instance_mask"] = resized_instance_masks + + instance_coords = data.pop("instance_coords") + instance_coords = instance_coords.astype(np.float) + instance_coords[..., 0::2] *= 1.0 * size_w / w + instance_coords[..., 1::2] *= 1.0 * size_h / h + data["instance_coords"] = instance_coords + + return data + + +def setup_size(size: Any, error_msg="Need a tuple of length 2"): + if isinstance(size, int): + return size, size + + if isinstance(size, (list, tuple)) and len(size) == 1: + return size[0], size[0] + + if len(size) != 2: + raise ValueError(error_msg) + + return size + + +@register_transformations(name="random_gamma_correction", type="image") +class RandomGammaCorrection(BaseTransformation): + def __init__(self, opts): + gamma_range = getattr( + opts, "image_augmentation.random_gamma_correction.gamma", (0.25, 1.75) + ) + p = getattr(opts, "image_augmentation.random_gamma_correction.p", 0.5) + super(RandomGammaCorrection, self).__init__(opts=opts) + self.gamma = setup_size(gamma_range) + self.p = p + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.random-gamma-correction.enable", + action="store_true", + help="use gamma correction", + ) + group.add_argument( + "--image-augmentation.random-gamma-correction.gamma", + type=float or tuple, + default=(0.5, 1.5), + help="Gamma range", + ) + group.add_argument( + "--image-augmentation.random-gamma-correction.p", + type=float, + default=0.5, + help="Probability that {} will be applied".format(cls.__name__), + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if random.random() <= self.p: + img = data["image"] + gamma = random.uniform(self.gamma[0], self.gamma[1]) + table = np.array( + [((i / 255.0) ** gamma) * 255 for i in np.arange(0, 256)] + ).astype("uint8") + img = cv2.LUT(img, table) + data["image"] = img + return data + + def __repr__(self): + return "{}(gamma={}, p={})".format(self.__class__.__name__, self.gamma, self.p) + + +@register_transformations(name="random_resize", type="image") +class RandomResize(BaseTransformation): + def __init__(self, opts): + min_size = getattr(opts, "image_augmentation.random_resize.min_size", 256) + max_size = getattr(opts, "image_augmentation.random_resize.max_size", 1024) + interpolation = getattr( + opts, "image-augmentation.random_resize.interpolation", "bilinear" + ) + super(RandomResize, self).__init__(opts=opts) + self.min_size = min_size + self.max_size = max_size + self.interpolation = _cv2_interpolation(interpolation=interpolation) + + def __call__(self, data: Dict) -> Dict: + random_size = random.randint(self.min_size, self.max_size) + return _resize_fn(data, size=random_size, interpolation=self.interpolation) + + def __repr__(self): + return "{}(min_size={}, max_size={}, interpolation={})".format( + self.__class__.__name__, + self.min_size, + self.max_size, + _cv2_to_str_interpolation[self.interpolation], + ) + + +@register_transformations(name="random_zoom_out", type="image") +class RandomZoomOut(BaseTransformation): + def __init__(self, opts, size: Optional[Sequence or int] = None): + side_range = getattr( + opts, "image_augmentation.random_zoom_out.side_range", [1, 4] + ) + p = getattr(opts, "image_augmentation.random_zoom_out.p", 0.5) + super(RandomZoomOut, self).__init__(opts=opts) + self.fill = 0.5 + self.side_range = side_range + self.p = p + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-zoom-out.enable", + action="store_true", + help="Use random scale", + ) + group.add_argument( + "--image-augmentation.random-zoom-out.side-range", + type=list or tuple, + default=[1, 4], + help="Side range", + ) + group.add_argument( + "--image-augmentation.random-zoom-out.p", + type=float, + default=0.5, + help="Probability of applying RandomZoomOut transformation", + ) + return parser + + def zoom_out( + self, image: np.ndarray, boxes: Optional[np.ndarray] = None + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + height, width, depth = image.shape + ratio = random.uniform(self.side_range[0], self.side_range[1]) + left = int(random.uniform(0, width * ratio - width)) + top = int(random.uniform(0, height * ratio - height)) + + expand_image = ( + np.ones((int(height * ratio), int(width * ratio), depth), dtype=image.dtype) + * self.fill + ) + expand_image[top : top + height, left : left + width] = image + + expand_boxes = None + if boxes is not None: + expand_boxes = boxes.copy() + expand_boxes[:, :2] += (left, top) + expand_boxes[:, 2:] += (left, top) + + return expand_image, expand_boxes + + def __call__(self, data: Dict) -> Dict: + if random.random() > self.p: + return data + img = data["image"] + + boxes = None + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + + img, boxes = self.zoom_out(image=img, boxes=boxes) + + data["image"] = img + data["box_coordinates"] = boxes + + return data + + def __repr__(self): + return "{}(min_scale={}, max_scale={}, interpolation={})".format( + self.__class__.__name__, + self.min_scale, + self.max_scale, + _cv2_to_str_interpolation[self.interpolation], + ) + + +@register_transformations(name="random_scale", type="image") +class RandomScale(BaseTransformation): + def __init__(self, opts, size: Optional[Sequence or int] = None): + min_scale = getattr(opts, "image_augmentation.random_scale.min_scale", 0.5) + max_scale = getattr(opts, "image_augmentation.random_scale.max_scale", 2.0) + interpolation = getattr( + opts, "image_augmentation.random_scale.interpolation", "bilinear" + ) + super(RandomScale, self).__init__(opts=opts) + self.min_scale = min_scale + self.max_scale = max_scale + self.interpolation = _cv2_interpolation(interpolation) + self.size = None + if size is not None: + self.size = setup_size(size) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-scale.enable", + action="store_true", + help="Use random scale", + ) + group.add_argument( + "--image-augmentation.random-scale.min-scale", + type=float, + default=0.5, + help="Min scale", + ) + group.add_argument( + "--image-augmentation.random-scale.max-scale", + type=float, + default=2.0, + help="Max scale", + ) + group.add_argument( + "--image-augmentation.random-scale.interpolation", + type=str, + default="bilinear", + help="Interpolation method", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + scale = random.uniform(self.min_scale, self.max_scale) + + img = data["image"] + if self.size is None: + height, width = img.shape[:2] + else: + height, width = self.size + target_height, target_width = int(height * scale), int(width * scale) + img = cv2.resize( + img, dsize=(target_width, target_height), interpolation=self.interpolation + ) + data["image"] = img + + if "mask" in data: + mask = data.pop("mask") + mask = cv2.resize( + mask, + dsize=(target_width, target_height), + interpolation=cv2.INTER_NEAREST, + ) + data["mask"] = mask + return data + + def __repr__(self): + return "{}(min_scale={}, max_scale={}, interpolation={})".format( + self.__class__.__name__, + self.min_scale, + self.max_scale, + _cv2_to_str_interpolation[self.interpolation], + ) + + +@register_transformations(name="random_resized_crop", type="image") +class RandomResizedCrop(BaseTransformation): + """ + Adapted from Pytorch Torchvision + """ + + def __init__(self, opts, size: tuple or int): + + interpolation = getattr( + opts, "image_augmentation.random_resized_crop.interpolation", "bilinear" + ) + scale = getattr( + opts, "image_augmentation.random_resized_crop.scale", (0.08, 1.0) + ) + ratio = getattr( + opts, + "image_augmentation.random_resized_crop.aspect_ratio", + (3.0 / 4.0, 4.0 / 3.0), + ) + + if not isinstance(scale, Sequence) or ( + isinstance(scale, Sequence) + and len(scale) != 2 + and 0.0 <= scale[0] < scale[1] + ): + logger.error( + "--image-augmentation.random-resized-crop.scale should be a tuple of length 2 " + "such that 0.0 <= scale[0] < scale[1]. Got: {}".format(scale) + ) + + if not isinstance(ratio, Sequence) or ( + isinstance(ratio, Sequence) + and len(ratio) != 2 + and 0.0 < ratio[0] < ratio[1] + ): + logger.error( + "--image-augmentation.random-resized-crop.aspect-ratio should be a tuple of length 2 " + "such that 0.0 < ratio[0] < ratio[1]. Got: {}".format(ratio) + ) + + ratio = (round(ratio[0], 3), round(ratio[1], 3)) + + super(RandomResizedCrop, self).__init__(opts=opts) + + self.scale = scale + self.size = setup_size(size=size) + + self.interpolation = _cv2_interpolation(interpolation) + self.ratio = ratio + + def get_params(self, height: int, width: int) -> (int, int, int, int): + area = height * width + for _ in range(10): + target_area = random.uniform(*self.scale) * area + log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = (1.0 * width) / height + if in_ratio < min(self.ratio): + w = width + h = int(round(w / min(self.ratio))) + elif in_ratio > max(self.ratio): + h = height + w = int(round(h * max(self.ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + def __call__(self, data: Dict) -> Dict: + img = data["image"] + height, width = img.shape[:2] + + i, j, h, w = self.get_params(height=height, width=width) + data = _crop_fn(data=data, i=i, j=j, h=h, w=w) + return _resize_fn(data=data, size=self.size, interpolation=self.interpolation) + + def __repr__(self): + return "{}(scale={}, ratio={}, interpolation={})".format( + self.__class__.__name__, + self.scale, + self.ratio, + _cv2_to_str_interpolation[self.interpolation], + ) + + +@register_transformations(name="random_crop", type="image") +class RandomCrop(BaseTransformation): + """ + Randomly crop the image to a given size + """ + + def __init__(self, opts, size: Sequence or int): + super(RandomCrop, self).__init__(opts=opts) + self.height, self.width = setup_size(size=size) + self.opts = opts + self.fill_mask = getattr(opts, "image_augmentation.random_crop.mask_fill", 255) + is_padding = not getattr( + opts, "image_augmentation.random_crop.resize_if_needed", False + ) + self.inp_process_fn = ( + self.pad_if_needed if not is_padding else self.resize_if_needed + ) + + @staticmethod + def get_params(img_h, img_w, target_h, target_w): + if img_w == target_w and img_h == target_h: + return 0, 0, img_h, img_w + i = random.randint(0, img_h - target_h) + j = random.randint(0, img_w - target_w) + return i, j, target_h, target_w + + @staticmethod + def get_params_from_box(boxes, img_h, img_w): + # x, y, w, h + offset = random.randint(20, 50) + start_x = max(0, int(round(np.min(boxes[..., 0]))) - offset) + start_y = max(0, int(round(np.min(boxes[..., 1]))) - offset) + end_x = min(int(round(np.max(boxes[..., 2]))) + offset, img_w) + end_y = min(int(round(np.max(boxes[..., 3]))) + offset, img_h) + + return start_y, start_x, end_y - start_y, end_x - start_x + + def pad_if_needed(self, data: Dict) -> Dict: + img = data["image"] + + h, w, channels = img.shape + pad_h = self.height - h if h < self.height else 0 + pad_w = self.width - w if w < self.width else 0 + + # padding format is (top, bottom, left, right) + img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0) + data["image"] = img + + if "mask" in data: + mask = data.pop("mask") + mask = cv2.copyMakeBorder( + mask, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.fill_mask + ) + data["mask"] = mask + return data + + def resize_if_needed(self, data: Dict) -> Dict: + img = data["image"] + + h, w, channels = img.shape + new_size = min(h + max(0, self.height - h), w + max(0, self.width - w)) + # resize while maintaining the aspect ratio + return _resize_fn(data, size=new_size, interpolation="bilinear") + + def __call__(self, data: Dict) -> Dict: + # box_info + if "box_coordinates" in data: + boxes = data.get("box_coordinates") + # crop the relevant area + image_h, image_w = data["image"].shape[:2] + box_i, box_j, box_h, box_w = self.get_params_from_box( + boxes, image_h, image_w + ) + data = _crop_fn(data, i=box_i, j=box_j, h=box_h, w=box_w) + + data = self.inp_process_fn(data) + img_h, img_w = data["image"].shape[:2] + i, j, h, w = self.get_params( + img_h=img_h, img_w=img_w, target_h=self.height, target_w=self.width + ) + data = _crop_fn(data=data, i=i, j=j, h=h, w=w) + + return data + + def __repr__(self): + return "{}(size=(h={}, w={}))".format( + self.__class__.__name__, self.height, self.width + ) + + +@register_transformations(name="random_flip", type="image") +class RandomFlip(BaseTransformation): + def __init__(self, opts): + super(RandomFlip, self).__init__(opts=opts) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.random-flip.enable", + action="store_true", + help="use random flipping", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + flip_choice = random.choices([0, 1, 2])[0] + if flip_choice in [0, 1]: # 1 - Horizontal, 0 - vertical + img = data["image"] + img = cv2.flip(img, flip_choice) + data["image"] = img + + if "mask" in data: + mask = data.pop("mask") + mask = cv2.flip(mask, flip_choice) + data["mask"] = mask + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + if flip_choice == 0: + height = img.shape[0] + boxes[:, 1::2] = height - boxes[:, 3::-2] + elif flip_choice == 1: + width = img.shape[1] + boxes[:, 0::2] = width - boxes[:, 2::-2] + + data["box_coordinates"] = boxes + + return data + + +@register_transformations(name="random_horizontal_flip", type="image") +class RandomHorizontalFlip(BaseTransformation): + def __init__(self, opts): + p = getattr(opts, "image_augmentation.random_horizontal_flip.p", 0.5) + super(RandomHorizontalFlip, self).__init__(opts=opts) + self.p = p + + def __call__(self, data: Dict) -> Dict: + + if random.random() <= self.p: + img = data["image"] + width = img.shape[1] + data["image"] = img[:, ::-1, ...] + + if "mask" in data: + mask = data.pop("mask") + mask = mask[:, ::-1, ...] + data["mask"] = mask + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + boxes[..., 0::2] = width - boxes[..., 2::-2] + data["box_coordinates"] = boxes + + if "instance_mask" in data: + assert "instance_coords" in data + + instance_coords = data.pop("instance_coords") + instance_coords[..., 0::2] = width - instance_coords[..., 2::-2] + data["instance_coords"] = instance_coords + + instance_masks = data.pop("instance_mask") + instance_masks = instance_masks[:, ::-1, ...] + data["instance_mask"] = instance_masks + + return data + + def __repr__(self): + return "{}(p={})".format(self.__class__.__name__, self.p) + + +@register_transformations(name="instance_processor", type="image") +class InstanceProcessor(BaseTransformation): + def __init__( + self, + opts, + instance_size: Optional[Union[int, Tuple[int, ...]]] = 16, + *args, + **kwargs + ): + super(InstanceProcessor, self).__init__(opts=opts) + self.instance_size = setup_size(instance_size) + + def __call__(self, data: Dict) -> Dict: + + if "instance_mask" in data: + assert "instance_coords" in data + instance_masks = data.pop("instance_mask") + instance_coords = data.pop("instance_coords") + instance_coords = instance_coords.astype(np.int) + + valid_boxes = (instance_coords[..., 3] > instance_coords[..., 1]) & ( + instance_coords[..., 2] > instance_coords[..., 0] + ) + instance_masks = instance_masks[..., valid_boxes] + instance_coords = instance_coords[valid_boxes] + + num_instances = instance_masks.shape[-1] + + resized_instances = [] + for i in range(num_instances): + instance_m = instance_masks[..., i] + box_coords = instance_coords[i] + instance_m = instance_m[ + box_coords[1] : box_coords[3], box_coords[0] : box_coords[2] + ] + instance_m = cv2.resize( + instance_m, + dsize=self.instance_size, + interpolation=cv2.INTER_NEAREST, + ) + resized_instances.append(instance_m) + + if len(resized_instances) == 0: + resized_instances = np.zeros( + shape=(self.instance_size[0], self.instance_size[1], 1), + dtype=np.uint8, + ) + instance_coords = np.array( + [[0, 0, self.instance_size[0], self.instance_size[1]]] + ) + else: + resized_instances = np.stack(resized_instances, axis=-1) + + data["instance_mask"] = resized_instances + data["instance_coords"] = instance_coords.astype(np.float) + return data + + +@register_transformations(name="random_vertical_flip", type="image") +class RandomVerticalFlip(BaseTransformation): + def __init__(self, opts): + p = getattr(opts, "image_augmentation.random_vertical_flip.p", 0.5) + super(RandomVerticalFlip, self).__init__(opts=opts) + self.p = p + + def __call__(self, data: Dict) -> Dict: + if random.random() <= self.p: + img = data["image"] + img = cv2.flip(img, 0) + data["image"] = img + + if "mask" in data: + mask = data.pop("mask") + mask = cv2.flip(mask, 0) + data["mask"] = mask + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + height = img.shape[0] + boxes[:, 1::2] = height - boxes[:, 3::-2] + + data["box_coordinates"] = boxes + + return data + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-vertical-flip.enable", + action="store_true", + help="use random vertical flipping", + ) + group.add_argument( + "--image-augmentation.random-vertical-flip.p", + type=float, + default=0.5, + help="Probability for random vertical flip", + ) + return parser + + def __repr__(self): + return "{}(p={})".format(self.__class__.__name__, self.p) + + +@register_transformations(name="random_rotation", type="image") +class RandomRotate(BaseTransformation): + def __init__(self, opts): + angle = getattr(opts, "image_augmentation.random_rotate.angle", 10.0) + fill = getattr(opts, "image_augmentation.random_rotate.mask_fill", 255) + interpolation = getattr( + opts, "image_augmentation.random_rotate.interpolation", "bilinear" + ) + p = getattr(opts, "image_augmentation.random_rotate.p", 0.5) + super(RandomRotate, self).__init__(opts=opts) + self.angle = angle + self.fill = fill + self.p = p + self.interpolation = _cv2_interpolation(interpolation) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + # These two arguments are CV2-specific + group.add_argument( + "--image-augmentation.random-rotate.interpolation", + type=str, + default="bilinear", + help="Interpolation method", + ) + group.add_argument( + "--image-augmentation.random-rotate.p", + type=float, + default=0.5, + help="Probability that {} will be applied".format(cls.__name__), + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if random.random() <= self.p: + img = data["image"] + height, width = img.shape[:2] + + random_angle = random.uniform(-self.angle, self.angle) + rotation_mat = cv2.getRotationMatrix2D( + center=(width / 2, height / 2), angle=random_angle, scale=1 + ) + + img_rotated = cv2.warpAffine( + src=img, + M=rotation_mat, + dsize=(width, height), + flags=self.interpolation, + borderValue=0, + ) + data["image"] = img_rotated + + if "mask" in data: + mask = data.pop("mask") + mask_rotated = cv2.warpAffine( + src=mask, + M=rotation_mat, + dsize=(width, height), + flags=cv2.INTER_NEAREST, + borderValue=self.fill, + ) + data["mask"] = mask_rotated + + if "box_coordinates" in data: + raise NotImplementedError( + "RandomRotate is not implemented for box coordinates" + ) + + return data + + def __repr__(self): + return "{}(angle={}, interpolation={}, p={})".format( + self.__class__.__name__, + self.angle, + _cv2_to_str_interpolation[self.interpolation], + self.p, + ) + + +BLUR_METHODS = ["gauss", "median", "average", "none", "any"] + + +@register_transformations(name="random_blur", type="image") +class RandomBlur(BaseTransformation): + def __init__(self, opts): + kernel_range = getattr( + opts, "image_augmentation.random_blur.kernel_size", [3, 7] + ) + blur_type = getattr(opts, "image_augmentation.random_blur.kernel_type", "any") + p = getattr(opts, "image_augmentation.random_blur.p", 0.5) + super(RandomBlur, self).__init__(opts=opts) + self.kernel_range = setup_size(kernel_range) + assert 1 <= self.kernel_range[0] <= self.kernel_range[1], "Got: {}".format( + self.kernel_range + ) + self.blur_type = blur_type + self.p = p + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.random-blur.enable", + action="store_true", + help="use random blurring", + ) + + group.add_argument( + "--image-augmentation.random-blur.kernel-size", + type=tuple or int or list, + default=[3, 7], + help="Randomly sample the kernel size from the given range", + ) + group.add_argument( + "--image-augmentation.random-blur.kernel-type", + type=str, + choices=BLUR_METHODS, + default=255, + help="Value used to fill the area after rotation", + ) + group.add_argument( + "--image-augmentation.random-blur.p", + type=float, + default=0.5, + help="Probability that {} will be applied".format(cls.__name__), + ) + return parser + + def blur_median(self, img: np.ndarray, ksize_x: int, ksize_y: int) -> np.ndarray: + ksize = ksize_x if random.random() < 0.5 else ksize_y + img = cv2.medianBlur(src=img, ksize=ksize) + return img + + def blur_avg(self, img: np.ndarray, ksize_x: int, ksize_y: int) -> np.ndarray: + return cv2.blur(src=img, ksize=(ksize_x, ksize_y)) + + def blur_gauss(self, img: np.ndarray, ksize_x: int, ksize_y: int) -> np.ndarray: + return cv2.GaussianBlur(src=img, ksize=(ksize_x, ksize_y), sigmaX=0) + + def blur_any(self, img: np.ndarray, ksize_x: int, ksize_y: int) -> np.ndarray: + blur_method = random.choice(BLUR_METHODS[:-1]) + if blur_method == "gauss": + img = self.blur_gauss(img=img, ksize_x=ksize_x, ksize_y=ksize_y) + elif blur_method == "median": + img = self.blur_median(img=img, ksize_x=ksize_x, ksize_y=ksize_y) + elif blur_method == "average": + img = self.blur_avg(img=img, ksize_x=ksize_x, ksize_y=ksize_y) + return img + + def __call__(self, data: Dict) -> Dict: + if self.blur_type == "none": + return data + + ksize_x = random.randint(self.kernel_range[0], self.kernel_range[1]) + ksize_y = random.randint(self.kernel_range[0], self.kernel_range[1]) + ksize_x = (ksize_x // 2) * 2 + 1 + ksize_y = (ksize_y // 2) * 2 + 1 + + img = data["image"] + + if self.blur_type == "any": + img = self.blur_any(img, ksize_x=ksize_x, ksize_y=ksize_y) + elif self.blur_type == "gaussian" and random.random() <= self.p: + img = self.blur_gauss(img=img, ksize_x=ksize_x, ksize_y=ksize_y) + elif self.blur_type == "median" and random.random() <= self.p: + img = self.blur_median(img=img, ksize_x=ksize_x, ksize_y=ksize_y) + elif self.blur_type == "average" and random.random() <= self.p: + img = self.blur_avg(img=img, ksize_x=ksize_x, ksize_y=ksize_y) + + data["image"] = img + return data + + def __repr__(self): + if self.blur_type == "any": + blur_type = ["gaussian", "median", "average"] + else: + blur_type = self.blur_type + return "{}(blur_type={}, kernel_range={})".format( + self.__class__.__name__, blur_type, self.kernel_range + ) + + +@register_transformations(name="random_translate", type="image") +class RandomTranslate(BaseTransformation): + def __init__(self, opts): + translate_factor = getattr( + opts, "image_augmentation.random_translate.factor", 0.2 + ) + assert 0 < translate_factor < 0.5, "Factor should be between 0 and 0.5" + super(RandomTranslate, self).__init__(opts=opts) + + self.translation_factor = translate_factor + + def __call__(self, data: Dict) -> Dict: + img = data["image"] + + height, width = img.shape[:2] + th = int(math.ceil(random.uniform(0, self.translation_factor) * height)) + tw = int(math.ceil(random.uniform(0, self.translation_factor) * width)) + img_translated = np.zeros_like(img) + translate_from_left = True if random.random() <= 0.5 else False + if translate_from_left: + img_translated[th:, tw:] = img[: height - th, : width - tw] + else: + img_translated[: height - th, : width - tw] = img[th:, tw:] + data["image"] = img_translated + + if "mask" in data: + mask = data.pop("mask") + mask_translated = np.zeros_like(mask) + if translate_from_left: + mask_translated[th:, tw:] = mask[: height - th, : width - tw] + else: + mask_translated[: height - th, : width - tw] = mask[th:, tw:] + data["mask"] = mask_translated + return data + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-translate.enable", + action="store_true", + help="use random translation", + ) + group.add_argument( + "--image-augmentation.random-translate.factor", + type=float, + default=0.2, + help="Translate uniformly between (-u, u)", + ) + return parser + + def __repr__(self): + return "{}(factor={})".format(self.__class__.__name__, self.translation_factor) + + +@register_transformations(name="resize", type="image") +class Resize(BaseTransformation): + def __init__(self, opts, size, *args, **kwargs): + if not ( + isinstance(size, int) + or (isinstance(size, Sequence) and len(size) in (1, 2)) + ): + raise TypeError( + "Supported size args are int or tuple of length 2. Got inappropriate size arg: {}".format( + self.size + ) + ) + interpolation = getattr( + opts, "image_augmentation.resize.interpolation", "bilinear" + ) + super(Resize, self).__init__(opts=opts) + + self.size = size + self.interpolation = _cv2_interpolation(interpolation) + + def __call__(self, data: Dict) -> Dict: + return _resize_fn(data=data, size=self.size, interpolation=self.interpolation) + + def __repr__(self): + return "{}(size={}, interpolation={})".format( + self.__class__.__name__, + self.size, + _cv2_to_str_interpolation[self.interpolation], + ) + + +@register_transformations(name="box_absolute_coords", type="image") +class BoxAbsoluteCoords(BaseTransformation): + def __init__(self, opts): + super(BoxAbsoluteCoords, self).__init__(opts=opts) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.box-absolute-coords.enable", + action="store_true", + help="Convert box coordinates to absolute coordinates", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + + image = data["image"] + + height, width, channels = image.shape + boxes[..., 0::2] *= width + boxes[..., 1::2] *= height + + data["box_coordinates"] = boxes + return data + + +@register_transformations(name="box_percent_coords", type="image") +class BoxPercentCoords(BaseTransformation): + def __init__(self, opts): + super(BoxPercentCoords, self).__init__(opts=opts) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.box-percent-coords.enable", + action="store_true", + help="Convert box coordinates to percent", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + image = data["image"] + height, width, channels = image.shape + + boxes = boxes.astype(np.float) + + boxes[..., 0::2] /= width + boxes[..., 1::2] /= height + data["box_coordinates"] = boxes + + return data + + +@register_transformations(name="ssd_cropping", type="image") +class SSDCroping(BaseTransformation): + """Crop + Arguments: + img (Image): the image being input during training + boxes (Tensor): the original bounding boxes in pt form + labels (Tensor): the class labels for each bbox + mode (float tuple): the min and max jaccard overlaps + Return: + (img, boxes, classes) + img (Image): the cropped image + boxes (Tensor): the adjusted bounding boxes in pt form + labels (Tensor): the class labels for each bbox + """ + + def __init__(self, opts): + super(SSDCroping, self).__init__(opts=opts) + self.iou_sample_opts = getattr( + opts, + "image_augmentation.ssd_crop.iou_thresholds", + [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0], + ) + self.trials = getattr(opts, "image_augmentation.ssd_crop.n_trials", 40) + self.min_aspect_ratio = getattr( + opts, "image_augmentation.ssd_crop.min_aspect_ratio", 0.5 + ) + self.max_aspect_ratio = getattr( + opts, "image_augmentation.ssd_crop.max_aspect_ratio", 0.5 + ) + + def __call__(self, data: Dict) -> Dict: + if "box_coordinates" in data: + boxes = data["box_coordinates"] + + # guard against no boxes + if boxes.shape[0] == 0: + return data + + image = data["image"] + labels = data["box_labels"] + height, width = image.shape[:2] + + while True: + # randomly choose a mode + min_jaccard_overalp = random.choice(self.iou_sample_opts) + if min_jaccard_overalp == 0.0: + return data + + for _ in range(self.trials): + w = random.uniform(0.3 * width, width) + h = random.uniform(0.3 * height, height) + + aspect_ratio = h / w + if not ( + self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio + ): + continue + + left = random.uniform(0, width - w) + top = random.uniform(0, height - h) + + # convert to integer rect x1,y1,x2,y2 + rect = np.array([int(left), int(top), int(left + w), int(top + h)]) + + # calculate IoU (jaccard overlap) b/t the cropped and gt boxes + ious = jaccard_numpy(boxes, rect) + + # is min and max overlap constraint satisfied? if not try again + if ious.max() < min_jaccard_overalp: + continue + + # keep overlap with gt box IF center in sampled patch + centers = (boxes[:, :2] + boxes[:, 2:]) * 0.5 + + # mask in all gt boxes that above and to the left of centers + m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) + + # mask in all gt boxes that under and to the right of centers + m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) + + # mask in that both m1 and m2 are true + mask = m1 * m2 + + # have any valid boxes? try again if not + if not mask.any(): + continue + + # if image size is too small, try again + if (rect[3] - rect[1]) < 100 or (rect[2] - rect[0]) < 100: + continue + + # cut the crop from the image + image = image[rect[1] : rect[3], rect[0] : rect[2], :] + + # take only matching gt boxes + current_boxes = boxes[mask, :].copy() + + # take only matching gt labels + current_labels = labels[mask] + + # should we use the box left and top corner or the crop's + current_boxes[:, :2] = np.maximum(current_boxes[:, :2], rect[:2]) + # adjust to crop (by substracting crop's left,top) + current_boxes[:, :2] -= rect[:2] + + current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], rect[2:]) + # adjust to crop (by substracting crop's left,top) + current_boxes[:, 2:] -= rect[:2] + + data["image"] = image + data["box_labels"] = current_labels + data["box_coordinates"] = current_boxes + + if "mask" in data: + seg_mask = data.pop("mask") + seg_mask = seg_mask[rect[1] : rect[3], rect[0] : rect[2]] + data["mask"] = seg_mask + + if "instance_mask" in data: + assert "instance_coords" in data + + instance_masks = data.pop("instance_mask") + instance_masks = instance_masks[ + rect[1] : rect[3], rect[0] : rect[2], ... + ] + data["instance_mask"] = instance_masks + + instance_coords = data.pop("instance_coords") + # should we use the box left and top corner or the crop's + instance_coords[..., :2] = np.maximum( + instance_coords[..., :2], rect[:2] + ) + # adjust to crop (by substracting crop's left,top) + instance_coords[..., :2] -= rect[:2] + + instance_coords[..., 2:] = np.minimum( + instance_coords[..., 2:], rect[2:] + ) + # adjust to crop (by substracting crop's left,top) + instance_coords[..., 2:] -= rect[:2] + data["instance_coords"] = instance_coords + + return data + return data + + +@register_transformations(name="center_crop", type="image") +class CenterCrop(BaseTransformation): + def __init__(self, opts, size: Sequence or int): + super(CenterCrop, self).__init__(opts=opts) + if isinstance(size, Sequence) and len(size) == 2: + self.height, self.width = size[0], size[1] + elif isinstance(size, Sequence) and len(size) == 1: + self.height = self.width = size[0] + elif isinstance(size, int): + self.height = self.width = size + else: + logger.error("Scale should be either an int or tuple of ints") + + def __call__(self, data: Dict) -> Dict: + height, width = data["image"].shape[:2] + i = (height - self.height) // 2 + j = (width - self.width) // 2 + return _crop_fn(data=data, i=i, j=j, h=self.height, w=self.width) + + def __repr__(self): + return "{}(size=(h={}, w={}))".format( + self.__class__.__name__, self.height, self.width + ) + + +@register_transformations(name="random_jpeg_compress", type="image") +class RandomJPEGCompress(BaseTransformation): + def __init__(self, opts): + q_range = getattr( + opts, "image_augmentation.random_jpeg_compress.q_factor", (5, 25) + ) + if isinstance(q_range, (int, float)): + q_range = (max(q_range - 10, 0), q_range) + assert len(q_range) == 2 + assert q_range[0] <= q_range[1] + p = getattr(opts, "image_augmentation.random_jpeg_compress.p", 0.5) + super(RandomJPEGCompress, self).__init__(opts=opts) + self.q_factor = q_range + self.p = p + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-jpeg-compress.enable", + action="store_true", + help="use random compression", + ) + group.add_argument( + "--image-augmentation.random-jpeg-compress.q-factor", + type=int or tuple, + default=(5, 25), + help="Compression quality factor range", + ) + group.add_argument( + "--image-augmentation.random-jpeg-compress.p", + type=float, + default=0.5, + help="Probability that {} will be applied".format(cls.__name__), + ) + + return parser + + def __call__(self, data: Dict) -> Dict: + if random.random() <= self.p: + q_factor = random.randint(self.q_factor[0], self.q_factor[1]) + encoding_param = [int(cv2.IMWRITE_JPEG_QUALITY), q_factor] + + img = data["image"] + _, enc_img = cv2.imencode(".jpg", img, encoding_param) + comp_img = cv2.imdecode(enc_img, 1) + data["image"] = comp_img + + return data + + def __repr__(self): + return "{}(q_factor=({}, {}), p={})".format( + self.__class__.__name__, self.q_factor[0], self.q_factor[1], self.p + ) + + +@register_transformations(name="random_gauss_noise", type="image") +class RandomGaussianNoise(BaseTransformation): + def __init__(self, opts): + sigma_range = getattr( + opts, "image_augmentation.random_gauss_noise.sigma", (0.03, 0.3) + ) + if isinstance(sigma_range, (float, int)): + sigma_range = (0, sigma_range) + + assert len(sigma_range) == 2, "Got {}".format(sigma_range) + assert sigma_range[0] <= sigma_range[1] + p = getattr(opts, "image_augmentation.random_gauss_noise.p", 0.5) + super(RandomGaussianNoise, self).__init__(opts=opts) + self.sigma_low = sigma_range[0] + self.sigma_high = sigma_range[1] + self.p = p + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-gauss-noise.enable", + action="store_true", + help="use random gaussian noise", + ) + group.add_argument( + "--image-augmentation.random-gauss-noise.sigma", + type=float or tuple, + default=(0.03, 0.1), + help="Sigma (sqrt of variance) range for Gaussian noise. Default is (0.0001, 0.001).", + ) + group.add_argument( + "--image-augmentation.random-gauss-noise.p", + type=float, + default=0.5, + help="Probability that {} will be applied".format(cls.__name__), + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if random.random() <= self.p: + std = random.uniform(self.sigma_low, self.sigma_high) + + img = data["image"] + noise = np.random.normal(0.0, std, img.shape) * 255 + noisy_img = img + noise + + noisy_img = np.clip(noisy_img, 0, 255).astype(np.uint8) + data["image"] = noisy_img + return data + + def __repr__(self): + return "{}(sigma=({}, {}), p={})".format( + self.__class__.__name__, self.sigma_low, self.sigma_high, self.p + ) + + +@register_transformations(name="to_tensor", type="image") +class NumpyToTensor(BaseTransformation): + def __init__(self, opts, *args, **kwargs): + super(NumpyToTensor, self).__init__(opts=opts) + + def __call__(self, data: Dict) -> Dict: + # HWC --> CHW + img = data["image"] + img = img.transpose(2, 0, 1) + img = np.ascontiguousarray(img) + + # numpy to tensor + img_tensor = torch.from_numpy(img).float() + img_tensor = torch.div(img_tensor, 255.0) + data["image"] = img_tensor + + if "mask" in data: + mask = data.pop("mask") + if len(mask.shape) > 2 and mask.shape[-1] > 1: + mask = mask.transpose(2, 0, 1) + data["mask"] = torch.from_numpy(mask).long() + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + data["box_coordinates"] = torch.from_numpy(boxes).float() + + if "box_labels" in data: + box_labels = data.pop("box_labels") + data["box_labels"] = torch.from_numpy(box_labels) + + if "instance_mask" in data: + assert "instance_coords" in data + + instance_masks = data.pop("instance_mask") + # [H, W, N] --> [N, H, W] + instance_masks = instance_masks.transpose(2, 0, 1) + instance_masks = np.ascontiguousarray(instance_masks) + data["instance_mask"] = torch.from_numpy(instance_masks).long() + + instance_coords = data.pop("instance_coords") + data["instance_coords"] = torch.from_numpy(instance_coords).float() + return data + + +@register_transformations(name="random_order", type="image") +class RandomOrder(BaseTransformation): + def __init__(self, opts, img_transforms: list): + super(RandomOrder, self).__init__(opts=opts) + self.transforms = img_transforms + apply_k_factor = getattr(opts, "image_augmentation.random_order.apply_k", 1.0) + assert ( + 0.0 < apply_k_factor <= 1.0 + ), "--image-augmentation.random-order.apply-k should be between 0 and 1" + self.keep_t = int(math.ceil(len(self.transforms) * apply_k_factor)) + + def __call__(self, data: Dict) -> Dict: + random.shuffle(self.transforms) + for t in self.transforms[: self.keep_t]: + data = t(data) + return data + + def __repr__(self): + transform_str = ", ".join(str(t) for t in self.transforms) + repr_str = "{}(n_transforms={}, t_list=[{}]".format( + self.__class__.__name__, self.keep_t, transform_str + ) + return repr_str + + +@register_transformations(name="compose", type="image") +class Compose(BaseTransformation): + def __init__(self, opts, img_transforms: list): + super(Compose, self).__init__(opts=opts) + self.img_transforms = img_transforms + + def __call__(self, data: Dict) -> Dict: + for t in self.img_transforms: + data = t(data) + return data + + def __repr__(self): + transform_str = ", ".join("\n\t\t\t" + str(t) for t in self.img_transforms) + repr_str = "{}({})".format(self.__class__.__name__, transform_str) + return repr_str + + +@register_transformations(name="photo_metric_distort_opencv", type="image") +class PhotometricDistort(BaseTransformation): + def __init__(self, opts): + beta_min = getattr( + opts, "image_augmentation.photo_metric_distort_opencv.beta_min", -0.2 + ) + beta_max = getattr( + opts, "image_augmentation.photo_metric_distort_opencv.beta_max", 0.2 + ) + assert -0.5 <= beta_min < beta_max <= 0.5, "Got {} and {}".format( + beta_min, beta_max + ) + + alpha_min = getattr( + opts, "image_augmentation.photo_metric_distort_opencv.alpha_min", 0.5 + ) + alpha_max = getattr( + opts, "image_augmentation.photo_metric_distort_opencv.alpha_max", 1.5 + ) + assert 0 < alpha_min < alpha_max, "Got {} and {}".format(alpha_min, alpha_max) + + gamma_min = getattr( + opts, "image_augmentation.photo_metric_distort_opencv.gamma_min", 0.5 + ) + gamma_max = getattr( + opts, "image_augmentation.photo_metric_distort_opencv.gamma_max", 1.5 + ) + assert 0 < gamma_min < gamma_max, "Got {} and {}".format(gamma_min, gamma_max) + + delta_min = getattr( + opts, "image_augmentation.photo_metric_distort_opencv.delta_min", -0.05 + ) + delta_max = getattr( + opts, "image_augmentation.photo_metric_distort_opencv.delta_max", 0.05 + ) + assert -1.0 < delta_min < delta_max < 1.0, "Got {} and {}".format( + delta_min, delta_max + ) + + super(PhotometricDistort, self).__init__(opts=opts) + # for briightness + self.beta_min = beta_min + self.beta_max = beta_max + # for contrast + self.alpha_min = alpha_min + self.alpha_max = alpha_max + # for saturation + self.gamma_min = gamma_min + self.gamma_max = gamma_max + # for hue + self.delta_min = delta_min + self.delta_max = delta_max + self.p = getattr(opts, "image_augmentation.photo_metric_distort_opencv.p", 0.5) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.enable", + action="store_true", + help="Randomly apply photometric transformation", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.alpha-min", + type=float, + default=0.5, + help="Min. alpha value for contrast. Should be > 0", + ) + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.alpha-max", + type=float, + default=1.5, + help="Max. alpha value for contrast. Should be > 0", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.beta-min", + type=float, + default=-0.2, + help="Min. alpha value for brightness. Should be between -1 and 1.", + ) + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.beta-max", + type=float, + default=0.2, + help="Max. alpha value for brightness. Should be between -1 and 1.", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.gamma-min", + type=float, + default=0.5, + help="Min. alpha value for saturation. Should be > 0", + ) + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.gamma-max", + type=float, + default=1.5, + help="Max. alpha value for saturation. Should be > 0", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.delta-min", + type=float, + default=-0.05, + help="Min. alpha value for Hue. Should be between -1 and 1.", + ) + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.delta-max", + type=float, + default=0.05, + help="Max. alpha value for Hue. Should be between -1 and 1.", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort-opencv.p", + type=float, + default=0.5, + help="Prob of applying transformation", + ) + + return parser + + def apply_transformations(self, image): + def convert_to_uint8(img): + return np.clip(img, 0, 255).astype(np.uint8) + + rand_nums = np.random.rand(6) + + image = image.astype(np.float32) + + # apply random contrast + alpha = ( + random.uniform(self.alpha_min, self.alpha_max) + if rand_nums[0] < self.p + else 1.0 + ) + image *= alpha + + # Apply random brightness + beta = ( + (random.uniform(self.beta_min, self.beta_max) * 255) + if rand_nums[1] < self.p + else 0.0 + ) + image += beta + + image = convert_to_uint8(image) + image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + image = image.astype(np.float32) + + # Apply random saturation + gamma = ( + random.uniform(self.gamma_min, self.gamma_max) + if rand_nums[2] < self.p + else 1.0 + ) + image[..., 1] *= gamma + + # Apply random hue + delta = ( + int(random.uniform(self.delta_min, self.delta_max) * 255) + if rand_nums[3] < self.p + else 0.0 + ) + image[..., 0] += delta + + image = convert_to_uint8(image) + image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) + + if alpha == 1.0 and rand_nums[4] < self.p: + # apply contrast if earlier not applied + image = image.astype(np.float32) + alpha = random.uniform(self.alpha_min, self.alpha_max) + image *= alpha + image = convert_to_uint8(image) + + # Lightning noise + channels = image.shape[-1] + swap = np.random.permutation(range(channels)) if rand_nums[5] < self.p else None + if swap is not None: + image = image[..., swap] + + return image + + def __call__(self, data: Dict) -> Dict: + image = data.pop("image") + data["image"] = self.apply_transformations(image) + return data + + +# add by huangzp +@register_transformations(name="bit_plane", type="image") +class BitPlane(BaseTransformation): + def __init__(self, opts, h, w): + # min_size = getattr(opts, "image_augmentation.random_resize.min_size", 256) + # max_size = getattr(opts, "image_augmentation.random_resize.max_size", 1024) + # interpolation = getattr( + # opts, "image-augmentation.random_resize.interpolation", "bilinear" + # ) + super(BitPlane, self).__init__(opts=opts) + self.h = h + self.w = w + self.weight = np.int16(np.ones([h, w, 8, 3])) + self.bias = np.int16(np.ones([h, w, 8, 3])) + for i in range(8): + self.weight[:,:,i,:] = self.weight[:,:,i,:] * (2**(7-i)) + self.bias[:,:,i,:] = self.bias[:,:,i,:] * (2**i) + + def __call__(self, data: Dict) -> Dict: + img = data['image'] + new_img = (self.weight & img[:,:,None,:]) * self.bias + new_img = new_img.reshape(self.h, self.w, 24) + # h,w = img.shape[0], img.shape[1] + # new_img = np.zeros((h,w,24)) + # for c in range(3): + # for i in range(h): + # for j in range(w): + # n = str(np.binary_repr(img[i,j,c],8)) + # for k in range(8): + # new_img[i,j,3*k+c] = n[k] + # # TODO: to check it out + data['image'] = new_img + return data + + def __repr__(self): + return "{}(bit plane 3 3 3 3 3 3 3 3 sum 24 plane)".format( + self.__class__.__name__, + ) diff --git a/Adaptive Frequency Filters/data/transforms/image_pil.py b/Adaptive Frequency Filters/data/transforms/image_pil.py new file mode 100644 index 0000000..769d5df --- /dev/null +++ b/Adaptive Frequency Filters/data/transforms/image_pil.py @@ -0,0 +1,2159 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import copy +from PIL import Image, ImageFilter +from utils import logger +import numpy as np +import random +import torch +import math +import argparse +from torchvision import transforms as T +from torchvision.transforms import functional as F +from typing import Sequence, Dict, Any, Union, Tuple, List, Optional + +from . import register_transformations, BaseTransformation +from .utils import jaccard_numpy, setup_size + +INTERPOLATION_MODE_MAP = { + "nearest": T.InterpolationMode.NEAREST, + "bilinear": T.InterpolationMode.BILINEAR, + "bicubic": T.InterpolationMode.BICUBIC, + "cubic": T.InterpolationMode.BICUBIC, + "box": T.InterpolationMode.BOX, + "hamming": T.InterpolationMode.HAMMING, + "lanczos": T.InterpolationMode.LANCZOS, +} + + +def _interpolation_modes_from_str(name: str) -> T.InterpolationMode: + return INTERPOLATION_MODE_MAP[name] + + +def _crop_fn(data: Dict, top: int, left: int, height: int, width: int) -> Dict: + """Helper function for cropping""" + img = data["image"] + data["image"] = F.crop(img, top=top, left=left, height=height, width=width) + + if "mask" in data: + mask = data.pop("mask") + data["mask"] = F.crop(mask, top=top, left=left, height=height, width=width) + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + + area_before_cropping = (boxes[..., 2] - boxes[..., 0]) * ( + boxes[..., 3] - boxes[..., 1] + ) + + boxes[..., 0::2] = np.clip(boxes[..., 0::2] - left, a_min=0, a_max=left + width) + boxes[..., 1::2] = np.clip(boxes[..., 1::2] - top, a_min=0, a_max=top + height) + + area_after_cropping = (boxes[..., 2] - boxes[..., 0]) * ( + boxes[..., 3] - boxes[..., 1] + ) + area_ratio = area_after_cropping / (area_before_cropping + 1) + + # keep the boxes whose area is atleast 20% of the area before cropping + keep = area_ratio >= 0.2 + + box_labels = data.pop("box_labels") + + data["box_coordinates"] = boxes[keep] + data["box_labels"] = box_labels[keep] + + if "instance_mask" in data: + assert "instance_coords" in data + + instance_masks = data.pop("instance_mask") + data["instance_mask"] = F.crop( + instance_masks, top=top, left=left, height=height, width=width + ) + + instance_coords = data.pop("instance_coords") + instance_coords[..., 0::2] = np.clip( + instance_coords[..., 0::2] - left, a_min=0, a_max=left + width + ) + instance_coords[..., 1::2] = np.clip( + instance_coords[..., 1::2] - top, a_min=0, a_max=top + height + ) + data["instance_coords"] = instance_coords + + return data + + +def _resize_fn( + data: Dict, + size: Union[Sequence, int], + interpolation: Optional[T.InterpolationMode or str] = T.InterpolationMode.BILINEAR, +) -> Dict: + """Helper function for resizing""" + img = data["image"] + + w, h = F.get_image_size(img) + + if isinstance(size, Sequence) and len(size) == 2: + size_h, size_w = size[0], size[1] + elif isinstance(size, int): + if (w <= h and w == size) or (h <= w and h == size): + return data + + if w < h: + size_h = int(size * h / w) + + size_w = size + else: + size_w = int(size * w / h) + size_h = size + else: + raise TypeError( + "Supported size args are int or tuple of length 2. Got inappropriate size arg: {}".format( + size + ) + ) + + if isinstance(interpolation, str): + interpolation = _interpolation_modes_from_str(name=interpolation) + + data["image"] = F.resize( + img=img, size=[size_h, size_w], interpolation=interpolation + ) + + if "mask" in data: + mask = data.pop("mask") + # mask can be a PIL or Tensor. + # Especially for Mask-RCNN, we may have tensors with first dimension as 0. + # In that case, resize, won't work. + # A workaround is that we check for the instance of a Tensor and then check its dimension. + if isinstance(mask, torch.Tensor) and mask.shape[0] == 0: + # It's empty tensor. + resized_mask = torch.zeros( + [0, size_h, size_w], dtype=mask.dtype, device=mask.device + ) + else: + resized_mask = F.resize( + img=mask, + size=[size_h, size_w], + interpolation=T.InterpolationMode.NEAREST, + ) + data["mask"] = resized_mask + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + boxes[:, 0::2] *= 1.0 * size_w / w + boxes[:, 1::2] *= 1.0 * size_h / h + data["box_coordinates"] = boxes + + if "instance_mask" in data: + assert "instance_coords" in data + + instance_masks = data.pop("instance_mask") + + resized_instance_masks = F.resize( + img=instance_masks, + size=[size_h, size_w], + interpolation=T.InterpolationMode.NEAREST, + ) + data["instance_mask"] = resized_instance_masks + + instance_coords = data.pop("instance_coords") + instance_coords = instance_coords.astype(np.float) + instance_coords[..., 0::2] *= 1.0 * size_w / w + instance_coords[..., 1::2] *= 1.0 * size_h / h + data["instance_coords"] = instance_coords + + return data + + +def _pad_fn( + data: Dict, + padding: Union[int, Sequence], + fill: Optional[int] = 0, + padding_mode: Optional[str] = "constant", +) -> Dict: + # Taken from the functional_tensor.py pad + if isinstance(padding, int): + pad_left = pad_right = pad_top = pad_bottom = padding + elif len(padding) == 1: + pad_left = pad_right = pad_top = pad_bottom = padding[0] + elif len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + else: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + + padding = [pad_left, pad_top, pad_right, pad_bottom] + data["image"] = F.pad(data.pop("image"), padding, fill, padding_mode) + + if "mask" in data: + data["mask"] = F.pad(data.pop("mask"), padding, 0, "constant") + + if "box_coordinates" in data: + # labels remain unchanged + boxes = data.pop("box_coordinates") + boxes[:, 0::2] += pad_left + boxes[:, 1::2] += pad_top + data["box_coordinates"] = boxes + + return data + + +@register_transformations(name="fixed_size_crop", type="image_pil") +class FixedSizeCrop(BaseTransformation): + def __init__( + self, opts, size: Optional[Union[int, Tuple[int, int]]] = None, *args, **kwargs + ): + super().__init__(opts, *args, **kwargs) + # size can be passed as an argument or using config. + # The argument is useful when implementing variable samplers + if size is None: + size = getattr(opts, "image_augmentation.fixed_size_crop.size", None) + fill = getattr(opts, "image_augmentation.fixed_size_crop.fill", 0) + padding_mode = getattr( + opts, "image_augmentation.fixed_size_crop.padding_mode", "constant" + ) + size = setup_size( + size, + error_msg="Please provide either int or (int, int) for size in {}.".format( + self.__class__.__name__ + ), + ) + self.crop_height = size[0] + self.crop_width = size[1] + self.fill = fill + self.padding_mode = padding_mode + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.fixed-size-crop.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.fixed-size-crop.size", + type=int, + nargs="+", + default=None, + help="Image size either as an int or (int, int).", + ) + group.add_argument( + "--image-augmentation.fixed-size-crop.fill", + type=int, + default=0, + help="Fill value to be used during padding operation. Defaults to 0.", + ) + group.add_argument( + "--image-augmentation.fixed-size-crop.padding-mode", + type=str, + default="constant", + help="Padding modes. Defaults to constant", + ) + + return parser + + def __call__(self, data: Dict, *args, **kwargs) -> Dict: + img = data["image"] + width, height = F.get_image_size(img) + new_height = min(height, self.crop_height) + new_width = min(width, self.crop_width) + + if new_height != height or new_width != width: + offset_height = max(height - self.crop_height, 0) + offset_width = max(width - self.crop_width, 0) + + r = random.random() + top = int(offset_height * r) + left = int(offset_width * r) + + data = _crop_fn( + data, top=top, left=left, height=new_height, width=new_width + ) + + pad_bottom = max(self.crop_height - new_height, 0) + pad_right = max(self.crop_width - new_width, 0) + if pad_bottom != 0 or pad_right != 0: + data = _pad_fn( + data, + padding=[0, 0, pad_right, pad_bottom], + fill=self.fill, + padding_mode=self.padding_mode, + ) + return data + + def __repr__(self): + return "{}(crop_size=({}, {}), fill={}, padding_mode={})".format( + self.__class__.__name__, + self.crop_height, + self.crop_width, + self.fill, + self.padding_mode, + ) + + +@register_transformations(name="scale_jitter", type="image_pil") +class ScaleJitter(BaseTransformation): + """Randomly resizes the input within the scale range""" + + def __init__(self, opts, *args, **kwargs) -> None: + target_size = getattr(opts, "image_augmentation.scale_jitter.target_size", None) + if target_size is None: + logger.error( + "Target size can't be None in {}.".format(self.__class__.__name__) + ) + target_size = setup_size( + target_size, + error_msg="Need either an int or (int, int) for target size in {}".format( + self.__class__.__name__ + ), + ) + + scale_range = getattr(opts, "image_augmentation.scale_jitter.scale_range", None) + if scale_range is None: + logger.error( + "Scale range can't be None in {}".format(self.__class__.__name__) + ) + + if isinstance(scale_range, Sequence) and len(scale_range) == 2: + scale_range = scale_range + else: + logger.error( + "Need (float, float) for target size in {}".format( + self.__class__.__name__ + ) + ) + + if scale_range[0] > scale_range[1]: + logger.error( + "scale_range[1] >= scale_range[0] in {}. Got: {}".format( + self.__class__.__name__, scale_range[1], scale_range[0] + ) + ) + + interpolation = getattr( + opts, "image_augmentation.scale_jitter.interpolation", "bilinear" + ) + + if isinstance(interpolation, str): + interpolation = _interpolation_modes_from_str(name=interpolation) + + super().__init__(opts, *args, **kwargs) + self.target_size = target_size + self.scale_range = scale_range + self.interpolation = interpolation + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.scale-jitter.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.scale-jitter.interpolation", + type=str, + default="bilinear", + help="Interpolation method. Defaults to bilinear interpolation", + ) + group.add_argument( + "--image-augmentation.scale-jitter.target-size", + type=int, + nargs="+", + default=None, + help="Target image size either as an int or (int, int).", + ) + group.add_argument( + "--image-augmentation.scale-jitter.scale-range", + type=float, + nargs="+", + default=None, + help="Scale range as (float, float).", + ) + + return parser + + def __call__(self, data: Dict, *args, **kwargs) -> Dict: + img = data["image"] + orig_width, orig_height = F.get_image_size(img) + scale = self.scale_range[0] + random.random() * ( + self.scale_range[1] - self.scale_range[0] + ) + r = ( + min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) + * scale + ) + new_width = int(orig_width * r) + new_height = int(orig_height * r) + + data = _resize_fn( + data, size=(new_height, new_width), interpolation=self.interpolation + ) + return data + + def __repr__(self): + return "{}(scale_range={}, target_size={}, interpolation={})".format( + self.__class__.__name__, + self.scale_range, + self.target_size, + self.interpolation, + ) + + +@register_transformations(name="random_resized_crop", type="image_pil") +class RandomResizedCrop(BaseTransformation, T.RandomResizedCrop): + """ + This class crops a random portion of an image and resize it to a given size. + """ + + def __init__(self, opts, size: Union[Sequence, int], *args, **kwargs) -> None: + interpolation = getattr( + opts, "image_augmentation.random_resized_crop.interpolation", "bilinear" + ) + scale = getattr( + opts, "image_augmentation.random_resized_crop.scale", (0.08, 1.0) + ) + ratio = getattr( + opts, + "image_augmentation.random_resized_crop.aspect_ratio", + (3.0 / 4.0, 4.0 / 3.0), + ) + + BaseTransformation.__init__(self, opts=opts) + + T.RandomResizedCrop.__init__( + self, size=size, scale=scale, ratio=ratio, interpolation=interpolation + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.random-resized-crop.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-resized-crop.interpolation", + type=str, + default="bilinear", + choices=list(INTERPOLATION_MODE_MAP.keys()), + help="Interpolation method for resizing. Defaults to bilinear.", + ) + group.add_argument( + "--image-augmentation.random-resized-crop.scale", + type=tuple, + default=(0.08, 1.0), + help="Specifies the lower and upper bounds for the random area of the crop, before resizing." + " The scale is defined with respect to the area of the original image. Defaults to " + "(0.08, 1.0)", + ) + group.add_argument( + "--image-augmentation.random-resized-crop.aspect-ratio", + type=float or tuple, + default=(3.0 / 4.0, 4.0 / 3.0), + help="lower and upper bounds for the random aspect ratio of the crop, before resizing. " + "Defaults to (3./4., 4./3.)", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + img = data["image"] + i, j, h, w = super().get_params(img=img, scale=self.scale, ratio=self.ratio) + data = _crop_fn(data=data, top=i, left=j, height=h, width=w) + return _resize_fn(data=data, size=self.size, interpolation=self.interpolation) + + def __repr__(self) -> str: + return "{}(scale={}, ratio={}, size={}, interpolation={})".format( + self.__class__.__name__, + self.scale, + self.ratio, + self.size, + self.interpolation, + ) + + +@register_transformations(name="auto_augment", type="image_pil") +class AutoAugment(BaseTransformation, T.AutoAugment): + """ + This class implements the `AutoAugment data augmentation `_ method. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + policy_name = getattr( + opts, "image_augmentation.auto_augment.policy", "imagenet" + ) + interpolation = getattr( + opts, "image_augmentation.auto_augment.interpolation", "bilinear" + ) + if policy_name == "imagenet": + policy = T.AutoAugmentPolicy.IMAGENET + else: + raise NotImplemented + + if isinstance(interpolation, str): + interpolation = _interpolation_modes_from_str(name=interpolation) + + BaseTransformation.__init__(self, opts=opts) + T.AutoAugment.__init__(self, policy=policy, interpolation=interpolation) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.auto-augment.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.auto-augment.policy", + type=str, + default="imagenet", + help="Auto-augment policy name. Defaults to imagenet.", + ) + group.add_argument( + "--image-augmentation.auto-augment.interpolation", + type=str, + default="bilinear", + help="Auto-augment interpolation method. Defaults to bilinear interpolation", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if "box_coordinates" in data or "mask" in data or "instance_masks" in data: + logger.error( + "{} is only supported for classification tasks".format( + self.__class__.__name__ + ) + ) + + img = data["image"] + img = super().forward(img) + data["image"] = img + return data + + def __repr__(self) -> str: + return "{}(policy={}, interpolation={})".format( + self.__class__.__name__, self.policy, self.interpolation + ) + + +@register_transformations(name="rand_augment", type="image_pil") +class RandAugment(BaseTransformation, T.RandAugment): + """ + This class implements the `RandAugment data augmentation `_ method. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + num_ops = getattr(opts, "image_augmentation.rand_augment.num_ops", 2) + magnitude = getattr(opts, "image_augmentation.rand_augment.magnitude", 9) + num_magnitude_bins = getattr( + opts, "image_augmentation.rand_augment.num_magnitude_bins", 31 + ) + interpolation = getattr( + opts, "image_augmentation.rand_augment.interpolation", "bilinear" + ) + + BaseTransformation.__init__(self, opts=opts) + + if isinstance(interpolation, str): + interpolation = _interpolation_modes_from_str(name=interpolation) + + T.RandAugment.__init__( + self, + num_ops=num_ops, + magnitude=magnitude, + num_magnitude_bins=num_magnitude_bins, + interpolation=interpolation, + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.rand-augment.enable", + action="store_true", + help="Use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.rand-augment.num-ops", + type=int, + default=2, + help="Number of augmentation transformations to apply sequentially. Defaults to 2.", + ) + group.add_argument( + "--image-augmentation.rand-augment.magnitude", + type=int, + default=9, + help="Magnitude for all the transformations. Defaults to 9", + ) + group.add_argument( + "--image-augmentation.rand-augment.num-magnitude-bins", + type=int, + default=31, + help="The number of different magnitude values. Defaults to 31.", + ) + group.add_argument( + "--image-augmentation.rand-augment.interpolation", + type=str, + default="bilinear", + choices=list(INTERPOLATION_MODE_MAP.keys()), + help="Desired interpolation method. Defaults to bilinear", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if "box_coordinates" in data or "mask" in data or "instance_masks" in data: + logger.error( + "{} is only supported for classification tasks".format( + self.__class__.__name__ + ) + ) + + img = data["image"] + img = super().forward(img) + data["image"] = img + return data + + def __repr__(self) -> str: + return "{}(num_ops={}, magnitude={}, num_magnitude_bins={}, interpolation={})".format( + self.__class__.__name__, + self.num_ops, + self.magnitude, + self.num_magnitude_bins, + self.interpolation, + ) + + +@register_transformations(name="trivial_augment_wide", type="image_pil") +class TrivialAugmentWide(BaseTransformation, T.TrivialAugmentWide): + """ + This class implements the `TrivialAugment (Wide) data augmentation `_ method. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + num_magnitude_bins = getattr( + opts, "image_augmentation.trivial_augment_wide.num_magnitude_bins", 31 + ) + interpolation = getattr( + opts, "image_augmentation.trivial_augment_wide.interpolation", "bilinear" + ) + + BaseTransformation.__init__(self, opts=opts) + + if isinstance(interpolation, str): + interpolation = _interpolation_modes_from_str(name=interpolation) + + T.TrivialAugmentWide.__init__( + self, + num_magnitude_bins=num_magnitude_bins, + interpolation=interpolation, + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.trivial-augment-wide.enable", + action="store_true", + help="Use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.trivial-augment-wide.num-magnitude-bins", + type=int, + default=31, + help="The number of different magnitude values. Defaults to 31.", + ) + group.add_argument( + "--image-augmentation.trivial-augment-wide.interpolation", + type=str, + default="bilinear", + choices=list(INTERPOLATION_MODE_MAP.keys()), + help="Desired interpolation method. Defaults to bilinear", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if "box_coordinates" in data or "mask" in data or "instance_masks" in data: + logger.error( + "{} is only supported for classification tasks".format( + self.__class__.__name__ + ) + ) + + img = data["image"] + img = super().forward(img) + data["image"] = img + return data + + def __repr__(self) -> str: + return "{}(num_magnitude_bins={}, interpolation={})".format( + self.__class__.__name__, + self.num_magnitude_bins, + self.interpolation, + ) + + +@register_transformations(name="random_horizontal_flip", type="image_pil") +class RandomHorizontalFlip(BaseTransformation): + """ + This class implements random horizontal flipping method + """ + + def __init__(self, opts, *args, **kwargs) -> None: + p = getattr(opts, "image_augmentation.random_horizontal_flip.p", 0.5) + super().__init__(opts=opts) + self.p = p + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-horizontal-flip.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-horizontal-flip.p", + type=float, + default=0.5, + help="Probability for applying random horizontal flip", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if random.random() <= self.p: + img = data["image"] + width, height = F.get_image_size(img) + data["image"] = F.hflip(img) + + if "mask" in data: + mask = data.pop("mask") + data["mask"] = F.hflip(mask) + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + boxes[..., 0::2] = width - boxes[..., 2::-2] + data["box_coordinates"] = boxes + + if "instance_mask" in data: + assert "instance_coords" in data + + instance_coords = data.pop("instance_coords") + instance_coords[..., 0::2] = width - instance_coords[..., 2::-2] + data["instance_coords"] = instance_coords + + instance_masks = data.pop("instance_mask") + data["instance_mask"] = F.hflip(instance_masks) + return data + + def __repr__(self) -> str: + return "{}(p={})".format(self.__class__.__name__, self.p) + + +@register_transformations(name="random_rotate", type="image_pil") +class RandomRotate(BaseTransformation): + """ + This class implements random rotation method + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts=opts) + self.angle = getattr(opts, "image_augmentation.random_rotate.angle", 10) + self.mask_fill = getattr(opts, "image_augmentation.random_rotate.mask_fill", 0) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-rotate.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-rotate.angle", + type=float, + default=10, + help="Angle for rotation. Defaults to 10. The angle is sampled " + "uniformly from [-angle, angle]", + ) + group.add_argument( + "--image-augmentation.random-rotate.mask-fill", + default=0, + help="Fill value for the segmentation mask. Defaults to 0.", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + + data_keys = list(data.keys()) + if "box_coordinates" in data_keys or "instance_mask" in data_keys: + logger.error("{} supports only images and masks") + + rand_angle = random.uniform(-self.angle, self.angle) + img = data.pop("image") + data["image"] = F.rotate( + img, angle=rand_angle, interpolation=F.InterpolationMode.BILINEAR, fill=0 + ) + if "mask" in data: + mask = data.pop("mask") + data["mask"] = F.rotate( + mask, + angle=rand_angle, + interpolation=F.InterpolationMode.NEAREST, + fill=self.mask_fill, + ) + return data + + def __repr__(self) -> str: + return "{}(angle={}, mask_fill={})".format( + self.__class__.__name__, self.angle, self.mask_fill + ) + + +@register_transformations(name="resize", type="image_pil") +class Resize(BaseTransformation): + """ + This class implements resizing operation. + + .. note:: + Two possible modes for resizing. + 1. Resize while maintaining aspect ratio. To enable this option, pass int as a size + 2. Resize to a fixed size. To enable this option, pass a tuple of height and width as a size + + .. note:: + If img_size is passed as a positional argument, then it will override size from args + """ + + def __init__( + self, + opts, + img_size: Optional[Union[Tuple[int, int], int]] = None, + *args, + **kwargs + ) -> None: + interpolation = getattr( + opts, "image_augmentation.resize.interpolation", "bilinear" + ) + super().__init__(opts=opts) + + # img_size argument is useful for implementing multi-scale sampler + size = ( + getattr(opts, "image_augmentation.resize.size", None) + if img_size is None + else img_size + ) + if size is None: + logger.error("Size can not be None in {}".format(self.__class__.__name__)) + + # Possible modes. + # 1. Resize while maintaining aspect ratio. To enable this option, pass int as a size + # 2. Resize to a fixed size. To enable this option, pass a tuple of height and width as a size + + if isinstance(size, Sequence) and len(size) == 1: + # List with single integer + size = size[0] + elif isinstance(size, Sequence) and len(size) > 2: + logger.error( + "The length of size should be either 1 or 2 in {}. Got: {}".format( + self.__class__.__name__, size + ) + ) + + if not (isinstance(size, Sequence) or isinstance(size, int)): + logger.error( + "Size needs to be either Tuple of length 2 or an integer in {}. Got: {}".format( + self.__class__.__name__, size + ) + ) + + self.size = size + self.interpolation = interpolation + self.maintain_aspect_ratio = True if isinstance(size, int) else False + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.resize.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.resize.interpolation", + type=str, + default="bilinear", + choices=list(INTERPOLATION_MODE_MAP.keys()), + help="Desired interpolation method for resizing. Defaults to bilinear", + ) + group.add_argument( + "--image-augmentation.resize.size", + type=int, + nargs="+", + default=256, + help="Resize image to the specified size. If int is passed, then shorter side is resized" + "to the specified size and longest side is resized while maintaining aspect ratio." + "Defaults to None.", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + return _resize_fn(data, size=self.size, interpolation=self.interpolation) + + def __repr__(self) -> str: + return "{}(size={}, interpolation={}, maintain_aspect_ratio={})".format( + self.__class__.__name__, + self.size, + self.interpolation, + self.maintain_aspect_ratio, + ) + + +@register_transformations(name="center_crop", type="image_pil") +class CenterCrop(BaseTransformation): + """ + This class implements center cropping method. + + .. note:: + This class assumes that the input size is greater than or equal to the desired size. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts=opts) + size = getattr(opts, "image_augmentation.center_crop.size", None) + + if size is None: + logger.error("Size cannot be None in {}".format(self.__class__.__name__)) + + if isinstance(size, Sequence) and len(size) == 2: + self.height, self.width = size[0], size[1] + elif isinstance(size, Sequence) and len(size) == 1: + self.height = self.width = size[0] + elif isinstance(size, int): + self.height = self.width = size + else: + logger.error("Scale should be either an int or tuple of ints") + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.center-crop.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.center-crop.size", + type=int, + nargs="+", + default=224, + help="Center crop size. Defaults to None.", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + width, height = F.get_image_size(data["image"]) + i = (height - self.height) // 2 + j = (width - self.width) // 2 + return _crop_fn(data=data, top=i, left=j, height=self.height, width=self.width) + + def __repr__(self) -> str: + return "{}(size=(h={}, w={}))".format( + self.__class__.__name__, self.height, self.width + ) + + +@register_transformations(name="ssd_cropping", type="image_pil") +class SSDCroping(BaseTransformation): + """ + This class implements cropping method for `Single shot object detector `_. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts=opts) + + self.iou_sample_opts = getattr( + opts, + "image_augmentation.ssd_crop.iou_thresholds", + [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0], + ) + self.trials = getattr(opts, "image_augmentation.ssd_crop.n_trials", 40) + self.min_aspect_ratio = getattr( + opts, "image_augmentation.ssd_crop.min_aspect_ratio", 0.5 + ) + self.max_aspect_ratio = getattr( + opts, "image_augmentation.ssd_crop.max_aspect_ratio", 2.0 + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.ssd-crop.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.ssd-crop.iou-thresholds", + type=float, + nargs="+", + default=[0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0], + help="IoU thresholds for SSD cropping. Defaults to [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]", + ) + group.add_argument( + "--image-augmentation.ssd-crop.n-trials", + type=int, + default=40, + help="Number of trials for SSD cropping. Defaults to 40", + ) + group.add_argument( + "--image-augmentation.ssd-crop.min-aspect-ratio", + type=float, + default=0.5, + help="Min. aspect ratio in SSD Cropping. Defaults to 0.5", + ) + group.add_argument( + "--image-augmentation.ssd-crop.max-aspect-ratio", + type=float, + default=2.0, + help="Max. aspect ratio in SSD Cropping. Defaults to 2.0", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if "box_coordinates" in data: + boxes = data["box_coordinates"] + + # guard against no boxes + if boxes.shape[0] == 0: + return data + + image = data["image"] + labels = data["box_labels"] + width, height = F.get_image_size(image) + + while True: + # randomly choose a mode + min_jaccard_overalp = random.choice(self.iou_sample_opts) + if min_jaccard_overalp == 0.0: + return data + + for _ in range(self.trials): + new_w = int(random.uniform(0.3 * width, width)) + new_h = int(random.uniform(0.3 * height, height)) + + aspect_ratio = new_h / new_w + if not ( + self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio + ): + continue + + left = int(random.uniform(0, width - new_w)) + top = int(random.uniform(0, height - new_h)) + + # convert to integer rect x1,y1,x2,y2 + rect = np.array([left, top, left + new_w, top + new_h]) + + # calculate IoU (jaccard overlap) b/t the cropped and gt boxes + ious = jaccard_numpy(boxes, rect) + + # is min and max overlap constraint satisfied? if not try again + if ious.max() < min_jaccard_overalp: + continue + + # keep overlap with gt box IF center in sampled patch + centers = (boxes[:, :2] + boxes[:, 2:]) * 0.5 + + # mask in all gt boxes that above and to the left of centers + m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) + + # mask in all gt boxes that under and to the right of centers + m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) + + # mask in that both m1 and m2 are true + mask = m1 * m2 + + # have any valid boxes? try again if not + if not mask.any(): + continue + + # if image size is too small, try again + if (rect[3] - rect[1]) < 100 or (rect[2] - rect[0]) < 100: + continue + + # cut the crop from the image + image = F.crop(image, top=top, left=left, width=new_w, height=new_h) + + # take only matching gt boxes + current_boxes = boxes[mask, :].copy() + + # take only matching gt labels + current_labels = labels[mask] + + # should we use the box left and top corner or the crop's + current_boxes[:, :2] = np.maximum(current_boxes[:, :2], rect[:2]) + # adjust to crop (by substracting crop's left,top) + current_boxes[:, :2] -= rect[:2] + + current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], rect[2:]) + # adjust to crop (by substracting crop's left,top) + current_boxes[:, 2:] -= rect[:2] + + data["image"] = image + data["box_labels"] = current_labels + data["box_coordinates"] = current_boxes + + if "mask" in data: + mask = data.pop("mask") + data["mask"] = F.crop( + mask, top=top, left=left, width=new_w, height=new_h + ) + + if "instance_mask" in data: + assert "instance_coords" in data + instance_masks = data.pop("instance_mask") + data["instance_mask"] = F.crop( + instance_masks, + top=top, + left=left, + width=new_w, + height=new_h, + ) + + instance_coords = data.pop("instance_coords") + # should we use the box left and top corner or the crop's + instance_coords[..., :2] = np.maximum( + instance_coords[..., :2], rect[:2] + ) + # adjust to crop (by substracting crop's left,top) + instance_coords[..., :2] -= rect[:2] + + instance_coords[..., 2:] = np.minimum( + instance_coords[..., 2:], rect[2:] + ) + # adjust to crop (by substracting crop's left,top) + instance_coords[..., 2:] -= rect[:2] + data["instance_coords"] = instance_coords + + return data + return data + + +@register_transformations(name="photo_metric_distort", type="image_pil") +class PhotometricDistort(BaseTransformation): + """ + This class implements Photometeric distorion. + + .. note:: + Hyper-parameters of PhotoMetricDistort in PIL and OpenCV are different. Be careful + """ + + def __init__(self, opts, *args, **kwargs) -> None: + # contrast + alpha_min = getattr( + opts, "image_augmentation.photo_metric_distort.alpha_min", 0.5 + ) + alpha_max = getattr( + opts, "image_augmentation.photo_metric_distort.alpha_max", 1.5 + ) + contrast = T.ColorJitter(contrast=[alpha_min, alpha_max]) + + # brightness + beta_min = getattr( + opts, "image_augmentation.photo_metric_distort.beta_min", 0.875 + ) + beta_max = getattr( + opts, "image_augmentation.photo_metric_distort.beta_max", 1.125 + ) + brightness = T.ColorJitter(brightness=[beta_min, beta_max]) + + # saturation + gamma_min = getattr( + opts, "image_augmentation.photo_metric_distort.gamma_min", 0.5 + ) + gamma_max = getattr( + opts, "image_augmentation.photo_metric_distort.gamma_max", 1.5 + ) + saturation = T.ColorJitter(saturation=[gamma_min, gamma_max]) + + # Hue + delta_min = getattr( + opts, "image_augmentation.photo_metric_distort.delta_min", -0.05 + ) + delta_max = getattr( + opts, "image_augmentation.photo_metric_distort.delta_max", 0.05 + ) + hue = T.ColorJitter(hue=[delta_min, delta_max]) + + super().__init__(opts=opts) + self._brightness = brightness + self._contrast = contrast + self._hue = hue + self._saturation = saturation + self.p = getattr(opts, "image_augmentation.photo_metric_distort.p", 0.5) + + def __repr__(self) -> str: + return "{}(contrast={}, brightness={}, saturation={}, hue={})".format( + self.__class__.__name__, + self._contrast.contrast, + self._brightness.brightness, + self._saturation.saturation, + self._hue.hue, + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.photo-metric-distort.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort.alpha-min", + type=float, + default=0.5, + help="Min. alpha value for contrast. Should be > 0. Defaults to 0.5", + ) + group.add_argument( + "--image-augmentation.photo-metric-distort.alpha-max", + type=float, + default=1.5, + help="Max. alpha value for contrast. Should be > 0. Defaults to 1.5", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort.beta-min", + type=float, + default=0.875, + help="Min. beta value for brightness. Should be > 0. Defaults to 0.8", + ) + group.add_argument( + "--image-augmentation.photo-metric-distort.beta-max", + type=float, + default=1.125, + help="Max. beta value for brightness. Should be > 0. Defaults to 1.2", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort.gamma-min", + type=float, + default=0.5, + help="Min. gamma value for saturation. Should be > 0. Defaults to 0.5", + ) + group.add_argument( + "--image-augmentation.photo-metric-distort.gamma-max", + type=float, + default=1.5, + help="Max. gamma value for saturation. Should be > 0. Defaults to 1.5", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort.delta-min", + type=float, + default=-0.05, + help="Min. delta value for Hue. Should be between -1 and 1. Defaults to -0.05", + ) + group.add_argument( + "--image-augmentation.photo-metric-distort.delta-max", + type=float, + default=0.05, + help="Max. delta value for Hue. Should be between -1 and 1. Defaults to 0.05", + ) + + group.add_argument( + "--image-augmentation.photo-metric-distort.p", + type=float, + default=0.5, + help="Probability for applying a distortion. Defaults to 0.5", + ) + + return parser + + def _apply_transformations(self, image): + r = np.random.rand(7) + + if r[0] < self.p: + image = self._brightness(image) + + contrast_before = r[1] < self.p + if contrast_before and r[2] < self.p: + image = self._contrast(image) + + if r[3] < self.p: + image = self._saturation(image) + + if r[4] < self.p: + image = self._hue(image) + + if not contrast_before and r[5] < self.p: + image = self._contrast(image) + + if r[6] < self.p and image.mode != "L": + # Only permute channels for RGB images + # [H, W, C] format + image_np = np.asarray(image) + n_channels = image_np.shape[2] + image_np = image_np[..., np.random.permutation(range(n_channels))] + image = Image.fromarray(image_np) + return image + + def __call__(self, data: Dict) -> Dict: + image = data.pop("image") + data["image"] = self._apply_transformations(image) + return data + + +@register_transformations(name="box_percent_coords", type="image_pil") +class BoxPercentCoords(BaseTransformation): + """ + This class converts the box coordinates to percent + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts=opts) + + def __call__(self, data: Dict) -> Dict: + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + image = data["image"] + width, height = F.get_image_size(image) + + boxes = boxes.astype(np.float) + + boxes[..., 0::2] /= width + boxes[..., 1::2] /= height + data["box_coordinates"] = boxes + + return data + + +@register_transformations(name="instance_processor", type="image_pil") +class InstanceProcessor(BaseTransformation): + """ + This class processes the instance masks. + """ + + def __init__( + self, + opts, + instance_size: Optional[Union[int, Tuple[int, ...]]] = 16, + *args, + **kwargs + ) -> None: + super().__init__(opts=opts) + self.instance_size = setup_size(instance_size) + + def __call__(self, data: Dict) -> Dict: + + if "instance_mask" in data: + assert "instance_coords" in data + instance_masks = data.pop("instance_mask") + instance_coords = data.pop("instance_coords") + instance_coords = instance_coords.astype(np.int) + + valid_boxes = (instance_coords[..., 3] > instance_coords[..., 1]) & ( + instance_coords[..., 2] > instance_coords[..., 0] + ) + instance_masks = instance_masks[valid_boxes] + instance_coords = instance_coords[valid_boxes] + + num_instances = instance_masks.shape[0] + + resized_instances = [] + for i in range(num_instances): + # format is [N, H, W] + instance_m = instance_masks[i] + box_coords = instance_coords[i] + + instance_m = F.crop( + instance_m, + top=box_coords[1], + left=box_coords[0], + height=box_coords[3] - box_coords[1], + width=box_coords[2] - box_coords[0], + ) + # need to unsqueeze and squeeze to make F.resize work + instance_m = F.resize( + instance_m.unsqueeze(0), + size=self.instance_size, + interpolation=T.InterpolationMode.NEAREST, + ).squeeze(0) + resized_instances.append(instance_m) + + if len(resized_instances) == 0: + resized_instances = torch.zeros( + size=(1, self.instance_size[0], self.instance_size[1]), + dtype=torch.long, + ) + instance_coords = np.array( + [[0, 0, self.instance_size[0], self.instance_size[1]]] + ) + else: + resized_instances = torch.stack(resized_instances, dim=0) + + data["instance_mask"] = resized_instances + data["instance_coords"] = instance_coords.astype(np.float) + return data + + +@register_transformations(name="random_resize", type="image_pil") +class RandomResize(BaseTransformation): + """ + This class implements random resizing method. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + min_ratio = getattr(opts, "image_augmentation.random_resize.min_ratio", 0.5) + max_ratio = getattr(opts, "image_augmentation.random_resize.max_ratio", 2.0) + interpolation = getattr( + opts, "image_augmentation.random_resize.interpolation", "bilinear" + ) + + max_scale_long_edge = getattr( + opts, "image_augmentation.random_resize.max_scale_long_edge", None + ) + max_scale_short_edge = getattr( + opts, "image_augmentation.random_resize.max_scale_short_edge", None + ) + + if max_scale_long_edge is None and max_scale_short_edge is not None: + logger.warning( + "max_scale_long_edge cannot be none when max_scale_short_edge is not None in {}. Setting both to " + "None".format(self.__class__.__name__) + ) + max_scale_long_edge = None + max_scale_short_edge = None + elif max_scale_long_edge is not None and max_scale_short_edge is None: + logger.warning( + "max_scale_short_edge cannot be none when max_scale_long_edge is not None in {}. Setting both to " + "None".format(self.__class__.__name__) + ) + max_scale_long_edge = None + max_scale_short_edge = None + + super().__init__(opts=opts) + self.min_ratio = min_ratio + self.max_ratio = max_ratio + + self.max_scale_long_edge = max_scale_long_edge + self.max_scale_short_edge = max_scale_short_edge + + self.interpolation = interpolation + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-resize.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-resize.max-scale-long-edge", + type=int, + default=None, + help="Max. value along the longest edge. Defaults to None", + ) + group.add_argument( + "--image-augmentation.random-resize.max-scale-short-edge", + type=int, + default=None, + help="Max. value along the shortest edge. Defaults to None.", + ) + + group.add_argument( + "--image-augmentation.random-resize.min-ratio", + type=float, + default=0.5, + help="Min ratio for random resizing. Defaults to 0.5", + ) + group.add_argument( + "--image-augmentation.random-resize.max-ratio", + type=float, + default=2.0, + help="Max ratio for random resizing. Defaults to 2.0", + ) + group.add_argument( + "--image-augmentation.random-resize.interpolation", + type=str, + default="bilinear", + choices=list(INTERPOLATION_MODE_MAP.keys()), + help="Desired interpolation method. Defaults to bilinear.", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + random_ratio = random.uniform(self.min_ratio, self.max_ratio) + + # compute the size + width, height = F.get_image_size(data["image"]) + if self.max_scale_long_edge is not None: + min_hw = min(height, width) + max_hw = max(height, width) + scale_factor = ( + min( + self.max_scale_long_edge / max_hw, + self.max_scale_short_edge / min_hw, + ) + * random_ratio + ) + # resize while maintaining aspect ratio + new_size = int(math.ceil(height * scale_factor)), int( + math.ceil(width * scale_factor) + ) + else: + new_size = int(math.ceil(height * random_ratio)), int( + math.ceil(width * random_ratio) + ) + # new_size should be a tuple of height and width + return _resize_fn(data, size=new_size, interpolation=self.interpolation) + + def __repr__(self) -> str: + return "{}(min_ratio={}, max_ratio={}, interpolation={}, max_long_edge={}, max_short_edge={})".format( + self.__class__.__name__, + self.min_ratio, + self.max_ratio, + self.interpolation, + self.max_scale_long_edge, + self.max_scale_short_edge, + ) + + +@register_transformations(name="random_short_size_resize", type="image_pil") +class RandomShortSizeResize(BaseTransformation): + """ + This class implements random resizing such that shortest side is between specified minimum and maximum values. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts=opts) + short_size_min = getattr( + opts, "image_augmentation.random_short_size_resize.short_side_min", None + ) + short_size_max = getattr( + opts, "image_augmentation.random_short_size_resize.short_side_max", None + ) + max_img_dim = getattr( + opts, "image_augmentation.random_short_size_resize.max_img_dim", None + ) + if short_size_min is None: + logger.error( + "Short side minimum value can't be None in {}".format( + self.__class__.__name__ + ) + ) + if short_size_max is None: + logger.error( + "Short side maximum value can't be None in {}".format( + self.__class__.__name__ + ) + ) + if max_img_dim is None: + logger.error( + "Max. image dimension value can't be None in {}".format( + self.__class__.__name__ + ) + ) + + if short_size_max <= short_size_min: + logger.error( + "Short side maximum value should be >= short side minimum value in {}. Got: {} and {}".format( + self.__class__.__name__, short_size_max, short_size_min + ) + ) + + interpolation = getattr( + opts, "image_augmentation.random_short_size_resize.interpolation", "bicubic" + ) + + self.short_side_min = short_size_min + self.short_side_max = short_size_max + self.max_img_dim = max_img_dim + self.interpolation = interpolation + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-short-size-resize.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-short-size-resize.short-side-min", + type=int, + default=None, + help="Minimum value for image's shortest side. Defaults to None.", + ) + group.add_argument( + "--image-augmentation.random-short-size-resize.short-side-max", + type=int, + default=None, + help="Maximum value for image's shortest side. Defaults to None.", + ) + group.add_argument( + "--image-augmentation.random-short-size-resize.interpolation", + type=str, + default="bicubic", + choices=list(INTERPOLATION_MODE_MAP.keys()), + help="Desired interpolation method. Defaults to bicubic", + ) + group.add_argument( + "--image-augmentation.random-short-size-resize.max-img-dim", + type=int, + default=None, + help="Max. image dimension. Defaults to None.", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + short_side = random.randint(self.short_side_min, self.short_side_max) + img_w, img_h = data["image"].size + scale = min( + short_side / min(img_h, img_w), self.max_img_dim / max(img_h, img_w) + ) + img_w = int(img_w * scale) + img_h = int(img_h * scale) + data = _resize_fn(data, size=(img_h, img_w), interpolation=self.interpolation) + return data + + def __repr__(self) -> str: + return "{}(short_side_min={}, short_side_max={}, interpolation={})".format( + self.__class__.__name__, + self.short_side_min, + self.short_side_max, + self.interpolation, + ) + + +@register_transformations(name="random_erasing", type="image_pil") +class RandomErasing(BaseTransformation, T.RandomErasing): + """ + This class randomly selects a region in a tensor and erases its pixels. + See `this paper `_ for details. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + BaseTransformation.__init__(self, opts=opts) + random_erase_p = getattr(opts, "image_augmentation.random_erase.p", 0.5) + T.RandomErasing.__init__(self, p=random_erase_p) + + self.random_erase_p = random_erase_p + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.random-erase.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-erase.p", + type=float, + default=0.5, + help="Probability that random erasing operation will be applied. Defaults to 0.5", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + data["image"] = super().forward(data.pop("image")) + return data + + def __repr__(self) -> str: + return "{}(random_erase_p={})".format( + self.__class__.__name__, self.random_erase_p + ) + + +@register_transformations(name="random_gaussian_blur", type="image_pil") +class RandomGaussianBlur(BaseTransformation): + """ + This method randomly blurs the input image. + """ + + def __init__(self, opts, *args, **kwargs): + super().__init__(opts=opts) + self.p = getattr(opts, "image_augmentation.random_gaussian_noise.p", 0.5) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.random-gaussian-noise.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-gaussian-noise.p", + type=float, + default=0.5, + help="Probability for applying {}".format(cls.__name__), + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if random.random() < self.p: + img = data.pop("image") + # radius is the standard devaition of the gaussian kernel + img = img.filter(ImageFilter.GaussianBlur(radius=random.random())) + data["image"] = img + return data + + +@register_transformations(name="random_crop", type="image_pil") +class RandomCrop(BaseTransformation): + """ + This method randomly crops an image area. + + .. note:: + If the size of input image is smaller than the desired crop size, the input image is first resized + while maintaining the aspect ratio and then cropping is performed. + """ + + def __init__( + self, + opts, + size: Union[Sequence, int], + ignore_idx: Optional[int] = 255, + *args, + **kwargs + ) -> None: + super().__init__(opts=opts) + self.height, self.width = setup_size(size=size) + self.opts = opts + self.seg_class_max_ratio = getattr( + opts, "image_augmentation.random_crop.seg_class_max_ratio", None + ) + self.ignore_idx = ignore_idx + self.num_repeats = 10 + self.seg_fill = getattr(opts, "image_augmentation.random_crop.mask_fill", 0) + pad_if_needed = getattr( + opts, "image_augmentation.random_crop.pad_if_needed", False + ) + self.if_needed_fn = ( + self._pad_if_needed if pad_if_needed else self._resize_if_needed + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.random-crop.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-crop.seg-class-max-ratio", + default=None, + type=float, + help="Max. ratio that single segmentation class can occupy. Defaults to None", + ) + group.add_argument( + "--image-augmentation.random-crop.pad-if-needed", + action="store_true", + help="Pad images if needed. Defaults to False, i.e., resizing will be performed", + ) + group.add_argument( + "--image-augmentation.random-crop.mask-fill", + type=int, + default=255, + help="Value to fill in segmentation mask in case of padding. Defaults to 255. " + "Generally, this value is the same as background or undefined class id.", + ) + return parser + + @staticmethod + def get_params(img_h, img_w, target_h, target_w): + if img_w == target_w and img_h == target_h: + return 0, 0, img_h, img_w + + i = random.randint(0, max(0, img_h - target_h)) + j = random.randint(0, max(0, img_w - target_w)) + return i, j, target_h, target_w + + @staticmethod + def get_params_from_box(boxes, img_h, img_w): + # x, y, w, h + offset = random.randint(20, 50) + start_x = max(0, int(round(np.min(boxes[..., 0]))) - offset) + start_y = max(0, int(round(np.min(boxes[..., 1]))) - offset) + end_x = min(int(round(np.max(boxes[..., 2]))) + offset, img_w) + end_y = min(int(round(np.max(boxes[..., 3]))) + offset, img_h) + + return start_y, start_x, end_y - start_y, end_x - start_x + + def get_params_from_mask(self, data, i, j, h, w): + img_w, img_h = F.get_image_size(data["image"]) + for _ in range(self.num_repeats): + temp_data = _crop_fn( + data=copy.deepcopy(data), top=i, left=j, height=h, width=w + ) + class_labels, cls_count = np.unique( + np.array(temp_data["mask"]), return_counts=True + ) + valid_cls_count = cls_count[class_labels != self.ignore_idx] + + if valid_cls_count.size == 0: + continue + + # compute the ratio of segmentation class with max. pixels to total pixels. + # If the ratio is less than seg_class_max_ratio, then exit the loop + total_valid_pixels = np.sum(valid_cls_count) + max_valid_pixels = np.max(valid_cls_count) + ratio = max_valid_pixels / total_valid_pixels + + if len(cls_count) > 1 and ratio < self.seg_class_max_ratio: + break + i, j, h, w = self.get_params( + img_h=img_h, img_w=img_w, target_h=self.height, target_w=self.width + ) + return i, j, h, w + + def _resize_if_needed(self, data: Dict) -> Dict: + img = data["image"] + + w, h = F.get_image_size(img) + # resize while maintaining the aspect ratio + new_size = min(h + max(0, self.height - h), w + max(0, self.width - w)) + + return _resize_fn( + data, size=new_size, interpolation=T.InterpolationMode.BILINEAR + ) + + def _pad_if_needed(self, data: Dict) -> Dict: + img = data.pop("image") + + w, h = F.get_image_size(img) + new_h = h + max(self.height - h, 0) + new_w = w + max(self.width - w, 0) + + pad_img = Image.new(img.mode, (new_w, new_h), color=0) + pad_img.paste(img, (0, 0)) + data["image"] = pad_img + + if "mask" in data: + mask = data.pop("mask") + pad_mask = Image.new(mask.mode, (new_w, new_h), color=self.seg_fill) + pad_mask.paste(mask, (0, 0)) + data["mask"] = pad_mask + + return data + + def __call__(self, data: Dict) -> Dict: + # box_info + if "box_coordinates" in data: + boxes = data.get("box_coordinates") + # crop the relevant area + image_w, image_h = F.get_image_size(data["image"]) + box_i, box_j, box_h, box_w = self.get_params_from_box( + boxes, image_h, image_w + ) + data = _crop_fn(data, top=box_i, left=box_j, height=box_h, width=box_w) + + data = self.if_needed_fn(data) + + img_w, img_h = F.get_image_size(data["image"]) + i, j, h, w = self.get_params( + img_h=img_h, img_w=img_w, target_h=self.height, target_w=self.width + ) + + if ( + "mask" in data + and self.seg_class_max_ratio is not None + and self.seg_class_max_ratio < 1.0 + ): + i, j, h, w = self.get_params_from_mask(data=data, i=i, j=j, h=h, w=w) + + data = _crop_fn(data=data, top=i, left=j, height=h, width=w) + return data + + def __repr__(self) -> str: + return "{}(size=(h={}, w={}), seg_class_max_ratio={}, seg_fill={})".format( + self.__class__.__name__, + self.height, + self.width, + self.seg_class_max_ratio, + self.seg_fill, + ) + + +@register_transformations(name="to_tensor", type="image_pil") +class ToTensor(BaseTransformation): + """ + This method converts an image into a tensor. + + .. note:: + We do not perform any mean-std normalization. If mean-std normalization is desired, please modify this class. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts=opts) + img_dtype = getattr(opts, "image_augmentation.to_tensor.dtype", "float") + self.img_dtype = torch.float + self.norm_factor = 255 + if img_dtype in ["half", "float16"]: + self.img_dtype = torch.float16 + elif img_dtype in ["uint8"]: + self.img_dtype = torch.uint8 + self.norm_factor = 1 + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser.add_argument( + "--image-augmentation.to-tensor.dtype", + type=str, + default="float", + help="Tensor data type. Default is float", + ) + return parser + + def __repr__(self): + return "{}(dtype={}, norm_factor={})".format( + self.__class__.__name__, self.img_dtype, self.norm_factor + ) + + def __call__(self, data: Dict) -> Dict: + # HWC --> CHW + img = data["image"] + + if F._is_pil_image(img): + # convert PIL image to tensor + img = F.pil_to_tensor(img).contiguous() + + data["image"] = img.to(dtype=self.img_dtype).div(self.norm_factor) + + if "mask" in data: + mask = data.pop("mask") + mask = np.array(mask) + + if len(mask.shape) not in (2, 3): + logger.error( + "Mask needs to be 2- or 3-dimensional. Got: {}".format(mask.shape) + ) + data["mask"] = torch.as_tensor(mask, dtype=torch.long) + + if "box_coordinates" in data: + boxes = data.pop("box_coordinates") + data["box_coordinates"] = torch.as_tensor(boxes, dtype=torch.float) + + if "box_labels" in data: + box_labels = data.pop("box_labels") + data["box_labels"] = torch.as_tensor(box_labels) + + if "instance_mask" in data: + assert "instance_coords" in data + instance_masks = data.pop("instance_mask") + data["instance_mask"] = instance_masks.to(dtype=torch.long) + + instance_coords = data.pop("instance_coords") + data["instance_coords"] = torch.as_tensor( + instance_coords, dtype=torch.float + ) + return data + + +@register_transformations(name="compose", type="image_pil") +class Compose(BaseTransformation): + """ + This method applies a list of transforms in a sequential fashion. + """ + + def __init__(self, opts, img_transforms: List, *args, **kwargs) -> None: + super().__init__(opts=opts) + self.img_transforms = img_transforms + + def __call__(self, data: Dict) -> Dict: + for t in self.img_transforms: + data = t(data) + return data + + def __repr__(self) -> str: + transform_str = ", ".join("\n\t\t\t" + str(t) for t in self.img_transforms) + repr_str = "{}({}\n\t\t)".format(self.__class__.__name__, transform_str) + return repr_str + + +@register_transformations(name="random_order", type="image_pil") +class RandomOrder(BaseTransformation): + """ + This method applies a list of all or few transforms in a random order. + """ + + def __init__(self, opts, img_transforms: List, *args, **kwargs) -> None: + super().__init__(opts=opts) + self.transforms = img_transforms + apply_k_factor = getattr(opts, "image_augmentation.random_order.apply_k", 1.0) + assert ( + 0.0 < apply_k_factor <= 1.0 + ), "--image-augmentation.random-order.apply-k should be > 0 and <= 1" + self.keep_t = int(math.ceil(len(self.transforms) * apply_k_factor)) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--image-augmentation.random-order.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.random-order.apply-k", + type=int, + default=1.0, + help="Apply K percent of transforms randomly. Value between 0 and 1. " + "Defaults to 1 (i.e., apply all transforms in random order).", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + random.shuffle(self.transforms) + for t in self.transforms[: self.keep_t]: + data = t(data) + return data + + def __repr__(self): + transform_str = ", ".join(str(t) for t in self.transforms) + repr_str = "{}(n_transforms={}, t_list=[{}]".format( + self.__class__.__name__, self.keep_t, transform_str + ) + return repr_str + + +@register_transformations(name="rand_augment_timm", type="image_pil") +class RandAugmentTimm(BaseTransformation): + """ + This class implements the `RandAugment data augmentation `_ method, + as described in `ResNet Strikes Back `_ paper + """ + + def __init__(self, opts, *args, **kwargs) -> None: + config_str = getattr( + opts, + "image_augmentation.rand_augment.timm_config_str", + "rand-m9-mstd0.5-inc1", + ) + + super().__init__(opts=opts, *args, **kwargs) + + rand_augment_transform = None + try: + from timm.data.transforms_factory import rand_augment_transform + except ModuleNotFoundError: + logger.error("Please install timm library") + + self.config_str = config_str + self.aug_fn = rand_augment_transform + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.rand-augment.use-timm-library", + action="store_true", + help="Use timm library for randaugment over PyTorch's implementation", + ) + group.add_argument( + "--image-augmentation.rand-augment.timm-config-str", + type=str, + default="rand-m9-mstd0.5-inc1", + help="Number of augmentation transformations to apply sequentially. Defaults to 2.", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if "box_coordinates" in data or "mask" in data or "instance_masks" in data: + logger.error( + "{} is only supported for classification tasks".format( + self.__class__.__name__ + ) + ) + + img = data["image"] + img_size_min = min(img.size) + aa_params = dict( + translate_const=int(img_size_min * 0.45), + img_mean=tuple([128, 128, 128]), + ) + img = self.aug_fn(self.config_str, aa_params)(img) + data["image"] = img + return data + + def __repr__(self) -> str: + return "{}(config_str={})".format(self.__class__.__name__, self.config_str) diff --git a/Adaptive Frequency Filters/data/transforms/image_torch.py b/Adaptive Frequency Filters/data/transforms/image_torch.py new file mode 100644 index 0000000..16739ab --- /dev/null +++ b/Adaptive Frequency Filters/data/transforms/image_torch.py @@ -0,0 +1,248 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import math +from typing import Dict +import argparse +import torch +from torchvision.transforms import functional as F +from torch.nn import functional as F_torch + +from utils import logger + +from . import register_transformations, BaseTransformation + + +# Copied from PyTorch Torchvision +@register_transformations(name="random_mixup", type="image_torch") +class RandomMixup(BaseTransformation): + """ + Given a batch of input images and labels, this class randomly applies the + `Mixup transformation `_ + + Args: + num_classes (int): Number of classes in the dataset + """ + + def __init__(self, opts, num_classes: int, *args, **kwargs) -> None: + super().__init__(opts=opts, *args, **kwargs) + alpha = getattr(opts, "image_augmentation.mixup.alpha", 1.0) + assert ( + num_classes > 0 + ), "Please provide a valid positive value for the num_classes." + assert alpha > 0, "Alpha param can't be zero." + + self.num_classes = num_classes + self.p = getattr(opts, "image_augmentation.mixup.p", 0.5) + self.alpha = alpha + self.inplace = getattr(opts, "image_augmentation.mixup.inplace", False) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.mixup.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--image-augmentation.mixup.alpha", + type=float, + default=1.0, + help="Alpha for MixUp augmentation. Defaults to 1.0", + ) + group.add_argument( + "--image-augmentation.mixup.p", + type=float, + default=0.5, + help="Probability for applying mixup augmentation. Defaults to 0.5", + ) + group.add_argument( + "--image-augmentation.mixup.inplace", + action="store_true", + default=False, + help="Apply Mixup augmentation inplace. Defaults to False.", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if torch.rand(1).item() >= self.p: + return data + + image_tensor, target_tensor = data.pop("samples"), data.pop("targets") + + if image_tensor.ndim != 4: + logger.error(f"Batch ndim should be 4. Got {image_tensor.ndim}") + if target_tensor.ndim != 1: + logger.error(f"Target ndim should be 1. Got {target_tensor.ndim}") + if not image_tensor.is_floating_point(): + logger.error( + f"Batch dtype should be a float tensor. Got {image_tensor.dtype}." + ) + if target_tensor.dtype != torch.int64: + logger.error( + f"Target dtype should be torch.int64. Got {target_tensor.dtype}" + ) + + if not self.inplace: + image_tensor = image_tensor.clone() + target_tensor = target_tensor.clone() + + if target_tensor.ndim == 1: + target_tensor = F_torch.one_hot( + target_tensor, num_classes=self.num_classes + ).to(dtype=image_tensor.dtype) + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = image_tensor.roll(1, 0) + target_rolled = target_tensor.roll(1, 0) + + # Implemented as on mixup paper, page 3. + lambda_param = float( + torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0] + ) + batch_rolled.mul_(1.0 - lambda_param) + image_tensor.mul_(lambda_param).add_(batch_rolled) + + target_rolled.mul_(1.0 - lambda_param) + target_tensor.mul_(lambda_param).add_(target_rolled) + + data["samples"] = image_tensor + data["targets"] = target_tensor + + return data + + def __repr__(self) -> str: + return "{}(num_classes={}, p={}, alpha={}, inplace={})".format( + self.__class__.__name__, self.num_classes, self.p, self.alpha, self.inplace + ) + + +@register_transformations(name="random_cutmix", type="image_torch") +class RandomCutmix(BaseTransformation): + """ + Given a batch of input images and labels, this class randomly applies the + `CutMix transformation `_ + + Args: + num_classes (int): Number of classes in the dataset + """ + + def __init__(self, opts, num_classes: int, *args, **kwargs) -> None: + super().__init__(opts=opts, *args, **kwargs) + alpha = getattr(opts, "image_augmentation.cutmix.alpha", 1.0) + assert ( + num_classes > 0 + ), "Please provide a valid positive value for the num_classes." + assert alpha > 0, "Alpha param can't be zero." + + self.num_classes = num_classes + self.p = getattr(opts, "image_augmentation.cutmix.p", 0.5) + self.alpha = alpha + self.inplace = getattr(opts, "image_augmentation.cutmix.inplace", False) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--image-augmentation.cutmix.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + + group.add_argument( + "--image-augmentation.cutmix.alpha", + type=float, + default=1.0, + help="Alpha for cutmix augmentation. Defaults to 1.0", + ) + group.add_argument( + "--image-augmentation.cutmix.p", + type=float, + default=0.5, + help="Probability for applying cutmix augmentation. Defaults to 0.5", + ) + group.add_argument( + "--image-augmentation.cutmix.inplace", + action="store_true", + default=False, + help="Apply cutmix operation inplace. Defaults to False", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + if torch.rand(1).item() >= self.p: + return data + + image_tensor, target_tensor = data.pop("samples"), data.pop("targets") + + if image_tensor.ndim != 4: + logger.error(f"Batch ndim should be 4. Got {image_tensor.ndim}") + if target_tensor.ndim != 1: + logger.error(f"Target ndim should be 1. Got {target_tensor.ndim}") + if not image_tensor.is_floating_point(): + logger.error( + f"Batch dtype should be a float tensor. Got {image_tensor.dtype}." + ) + if target_tensor.dtype != torch.int64: + logger.error( + f"Target dtype should be torch.int64. Got {target_tensor.dtype}" + ) + + if not self.inplace: + image_tensor = image_tensor.clone() + target_tensor = target_tensor.clone() + + if target_tensor.ndim == 1: + target_tensor = F_torch.one_hot( + target_tensor, num_classes=self.num_classes + ).to(dtype=image_tensor.dtype) + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = image_tensor.roll(1, 0) + target_rolled = target_tensor.roll(1, 0) + + # Implemented as on cutmix paper, page 12 (with minor corrections on typos). + lambda_param = float( + torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0] + ) + W, H = F.get_image_size(image_tensor) + + r_x = torch.randint(W, (1,)) + r_y = torch.randint(H, (1,)) + + r = 0.5 * math.sqrt(1.0 - lambda_param) + r_w_half = int(r * W) + r_h_half = int(r * H) + + x1 = int(torch.clamp(r_x - r_w_half, min=0)) + y1 = int(torch.clamp(r_y - r_h_half, min=0)) + x2 = int(torch.clamp(r_x + r_w_half, max=W)) + y2 = int(torch.clamp(r_y + r_h_half, max=H)) + + image_tensor[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] + lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) + + target_rolled.mul_(1.0 - lambda_param) + target_tensor.mul_(lambda_param).add_(target_rolled) + + data["samples"] = image_tensor + data["targets"] = target_tensor + + return data + + def __repr__(self) -> str: + return "{}(num_classes={}, p={}, alpha={}, inplace={})".format( + self.__class__.__name__, self.num_classes, self.p, self.alpha, self.inplace + ) diff --git a/Adaptive Frequency Filters/data/transforms/utils.py b/Adaptive Frequency Filters/data/transforms/utils.py new file mode 100644 index 0000000..746fe4f --- /dev/null +++ b/Adaptive Frequency Filters/data/transforms/utils.py @@ -0,0 +1,48 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from typing import Any +import numpy as np + + +def setup_size(size: Any, error_msg="Need a tuple of length 2"): + if size is None: + raise ValueError("Size can't be None") + + if isinstance(size, int): + return size, size + elif isinstance(size, (list, tuple)) and len(size) == 1: + return size[0], size[0] + + if len(size) != 2: + raise ValueError(error_msg) + + return size + + +def intersect(box_a, box_b): + """Computes the intersection between box_a and box_b""" + max_xy = np.minimum(box_a[:, 2:], box_b[2:]) + min_xy = np.maximum(box_a[:, :2], box_b[:2]) + inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf) + return inter[:, 0] * inter[:, 1] + + +def jaccard_numpy(box_a: np.ndarray, box_b: np.ndarray): + """ + Computes the intersection of two boxes. + Args: + box_a (np.ndarray): Boxes of shape [Num_boxes_A, 4] + box_b (np.ndarray): Box osf shape [Num_boxes_B, 4] + + Returns: + intersection over union scores. Shape is [box_a.shape[0], box_a.shape[1]] + """ + inter = intersect(box_a, box_b) + area_a = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1]) # [A,B] + area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]) # [A,B] + union = area_a + area_b - inter + return inter / union # [A,B] diff --git a/Adaptive Frequency Filters/data/transforms/video.py b/Adaptive Frequency Filters/data/transforms/video.py new file mode 100644 index 0000000..e29175e --- /dev/null +++ b/Adaptive Frequency Filters/data/transforms/video.py @@ -0,0 +1,609 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import random +import torch +import math +import argparse +from typing import Sequence, Dict, Any, Union, Tuple, List, Optional +from torch.nn import functional as F + +from utils import logger + +from . import register_transformations, BaseTransformation +from .utils import * + + +SUPPORTED_PYTORCH_INTERPOLATIONS = ["nearest", "bilinear", "bicubic"] + + +def _check_interpolation(interpolation): + if interpolation not in SUPPORTED_PYTORCH_INTERPOLATIONS: + inter_str = "Supported interpolation modes are:" + for i, j in enumerate(SUPPORTED_PYTORCH_INTERPOLATIONS): + inter_str += "\n\t{}: {}".format(i, j) + logger.error(inter_str) + return interpolation + + +def _crop_fn(data: Dict, i: int, j: int, h: int, w: int): + img = data["image"] + if not isinstance(img, torch.Tensor) and img.dim() != 4: + logger.error( + "Cropping requires 4-d tensor of shape NCHW or CNHW. Got {}-dimensional tensor".format( + img.dim() + ) + ) + + crop_image = img[..., i : i + h, j : j + w] + data["image"] = crop_image + + mask = data.get("mask", None) + if mask is not None: + crop_mask = mask[..., i : i + h, j : j + w] + data["mask"] = crop_mask + return data + + +def _resize_fn( + data: Dict, size: Union[Sequence, int], interpolation: Optional[str] = "bilinear" +): + img = data["image"] + + if isinstance(size, Sequence) and len(size) == 2: + size_h, size_w = size[0], size[1] + elif isinstance(size, int): + h, w = img.shape[-2:] + if (w <= h and w == size) or (h <= w and h == size): + return data + + if w < h: + size_h = int(size * h / w) + + size_w = size + else: + size_w = int(size * w / h) + size_h = size + else: + raise TypeError( + "Supported size args are int or tuple of length 2. Got inappropriate size arg: {}".format( + size + ) + ) + if isinstance(interpolation, str): + interpolation = _check_interpolation(interpolation) + img = F.interpolate( + input=img, + size=(size_w, size_h), + mode=interpolation, + align_corners=True if interpolation != "nearest" else None, + ) + data["image"] = img + + mask = data.get("mask", None) + if mask is not None: + mask = F.interpolate(input=mask, size=(size_w, size_h), mode="nearest") + data["mask"] = mask + + return data + + +def _check_rgb_video_tensor(clip): + if not isinstance(clip, torch.FloatTensor) or clip.dim() != 4: + logger.error( + "Video clip is either not an instance of FloatTensor or it is not a 4-d tensor (NCHW or CNHW)" + ) + + +@register_transformations(name="to_tensor", type="video") +class ToTensor(BaseTransformation): + """ + This method converts an image into a tensor. + + .. note:: + We do not perform any mean-std normalization. If mean-std normalization is desired, please modify this class. + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts=opts) + + def __call__(self, data: Dict) -> Dict: + # [C, N, H, W] + clip = data["image"] + if not isinstance(clip, torch.Tensor): + clip = torch.from_numpy(clip) + clip = clip.float() + + _check_rgb_video_tensor(clip=clip) + + # normalize between 0 and 1 + clip = torch.div(clip, 255.0) + data["image"] = clip + return data + + +@register_transformations(name="random_resized_crop", type="video") +class RandomResizedCrop(BaseTransformation): + """ + This class crops a random portion of an image and resize it to a given size. + """ + + def __init__(self, opts, size: Union[Tuple, int], *args, **kwargs) -> None: + interpolation = getattr( + opts, "video_augmentation.random_resized_crop.interpolation", "bilinear" + ) + scale = getattr( + opts, "video_augmentation.random_resized_crop.scale", (0.08, 1.0) + ) + ratio = getattr( + opts, + "video_augmentation.random_resized_crop.aspect_ratio", + (3.0 / 4.0, 4.0 / 3.0), + ) + + if not isinstance(scale, Sequence) or ( + isinstance(scale, Sequence) + and len(scale) != 2 + and 0.0 <= scale[0] < scale[1] + ): + logger.error( + "--video-augmentation.random-resized-crop.scale should be a tuple of length 2 " + "such that 0.0 <= scale[0] < scale[1]. Got: {}".format(scale) + ) + + if not isinstance(ratio, Sequence) or ( + isinstance(ratio, Sequence) + and len(ratio) != 2 + and 0.0 < ratio[0] < ratio[1] + ): + logger.error( + "--video-augmentation.random-resized-crop.aspect-ratio should be a tuple of length 2 " + "such that 0.0 < ratio[0] < ratio[1]. Got: {}".format(ratio) + ) + + ratio = (round(ratio[0], 3), round(ratio[1], 3)) + + super().__init__(opts=opts) + + self.scale = scale + self.size = setup_size(size=size) + + self.interpolation = _check_interpolation(interpolation) + self.ratio = ratio + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--video-augmentation.random-resized-crop.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--video-augmentation.random-resized-crop.interpolation", + type=str, + default="bilinear", + choices=SUPPORTED_PYTORCH_INTERPOLATIONS, + help="Desired interpolation method. Defaults to bilinear", + ) + group.add_argument( + "--video-augmentation.random-resized-crop.scale", + type=tuple, + default=(0.08, 1.0), + help="Specifies the lower and upper bounds for the random area of the crop, before resizing." + " The scale is defined with respect to the area of the original image. Defaults to " + "(0.08, 1.0)", + ) + group.add_argument( + "--video-augmentation.random-resized-crop.aspect-ratio", + type=float or tuple, + default=(3.0 / 4.0, 4.0 / 3.0), + help="lower and upper bounds for the random aspect ratio of the crop, before resizing. " + "Defaults to (3./4., 4./3.)", + ) + return parser + + def get_params(self, height: int, width: int) -> (int, int, int, int): + area = height * width + for _ in range(10): + target_area = random.uniform(*self.scale) * area + log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = (1.0 * width) / height + if in_ratio < min(self.ratio): + w = width + h = int(round(w / min(self.ratio))) + elif in_ratio > max(self.ratio): + h = height + w = int(round(h * max(self.ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + def __call__(self, data: Dict) -> Dict: + clip = data["image"] + _check_rgb_video_tensor(clip=clip) + + height, width = clip.shape[-2:] + + i, j, h, w = self.get_params(height=height, width=width) + data = _crop_fn(data=data, i=i, j=j, h=h, w=w) + return _resize_fn(data=data, size=self.size, interpolation=self.interpolation) + + def __repr__(self) -> str: + return "{}(scale={}, ratio={}, interpolation={})".format( + self.__class__.__name__, self.scale, self.ratio, self.interpolation + ) + + +@register_transformations(name="random_short_side_resize_crop", type="video") +class RandomShortSizeResizeCrop(BaseTransformation): + """ + This class first randomly resizes the input video such that shortest side is between specified minimum and + maximum values, adn then crops a desired size video. + + .. note:: + This class assumes that the video size after resizing is greater than or equal to the desired size. + """ + + def __init__(self, opts, size: Union[Tuple, int], *args, **kwargs) -> None: + interpolation = getattr( + opts, + "video_augmentation.random_short_side_resize_crop.interpolation", + "bilinear", + ) + short_size_min = getattr( + opts, + "video_augmentation.random_short_side_resize_crop.short_side_min", + None, + ) + short_size_max = getattr( + opts, + "video_augmentation.random_short_side_resize_crop.short_side_max", + None, + ) + + if short_size_min is None: + logger.error( + "Short side minimum value can't be None in {}".format( + self.__class__.__name__ + ) + ) + if short_size_max is None: + logger.error( + "Short side maximum value can't be None in {}".format( + self.__class__.__name__ + ) + ) + + if short_size_max <= short_size_min: + logger.error( + "Short side maximum value should be >= short side minimum value in {}. Got: {} and {}".format( + self.__class__.__name__, short_size_max, short_size_min + ) + ) + + super().__init__(opts=opts) + self.short_side_min = short_size_min + self.size = size + self.short_side_max = short_size_max + self.interpolation = _check_interpolation(interpolation) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--video-augmentation.random-short-side-resize-crop.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--video-augmentation.random-short-side-resize-crop.interpolation", + type=str, + default="bilinear", + choices=SUPPORTED_PYTORCH_INTERPOLATIONS, + help="Desired interpolation method. Defaults to bilinear", + ) + group.add_argument( + "--video-augmentation.random-short-side-resize-crop.short-side-min", + type=int, + default=None, + help="Minimum value for video's shortest side. Defaults to None.", + ) + group.add_argument( + "--video-augmentation.random-short-side-resize-crop.short-side-max", + type=int, + default=None, + help="Maximum value for video's shortest side. Defaults to None.", + ) + return parser + + def get_params(self, height, width) -> Tuple[int, int, int, int]: + th, tw = self.size + + if width == tw and height == th: + return 0, 0, height, width + + i = random.randint(0, height - th) + j = random.randint(0, width - tw) + return i, j, th, tw + + def __call__(self, data: Dict) -> Dict: + short_dim = random.randint(self.short_side_max, self.short_side_max) + # resize the video so that shorter side is short_dim + data = _resize_fn(data, size=short_dim, interpolation=self.interpolation) + + clip = data["image"] + _check_rgb_video_tensor(clip=clip) + height, width = clip.shape[-2:] + i, j, h, w = self.get_params(height=height, width=width) + # crop the video + return _crop_fn(data=data, i=i, j=j, h=h, w=w) + + def __repr__(self) -> str: + return "{}(size={}, short_size_range=({}, {}), interpolation={})".format( + self.__class__.__name__, + self.size, + self.short_side_min, + self.short_side_max, + self.interpolation, + ) + + +@register_transformations(name="random_crop", type="video") +class RandomCrop(BaseTransformation): + """ + This method randomly crops a video area. + + .. note:: + This class assumes that the input video size is greater than or equal to the desired size. + """ + + def __init__(self, opts, size: Union[Tuple, int], *args, **kwargs) -> None: + size = setup_size(size=size) + super().__init__(opts=opts) + self.size = size + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--video-augmentation.random-crop.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + return parser + + def get_params(self, height, width) -> Tuple[int, int, int, int]: + th, tw = self.size + + if width == tw and height == th: + return 0, 0, height, width + + i = random.randint(0, height - th) + j = random.randint(0, width - tw) + return i, j, th, tw + + def __call__(self, data: Dict) -> Dict: + clip = data["image"] + _check_rgb_video_tensor(clip=clip) + height, width = clip.shape[-2:] + i, j, h, w = self.get_params(height=height, width=width) + return _crop_fn(data=data, i=i, j=j, h=h, w=w) + + def __repr__(self) -> str: + return "{}(size={})".format(self.__class__.__name__, self.size) + + +@register_transformations(name="random_horizontal_flip", type="video") +class RandomHorizontalFlip(BaseTransformation): + """ + This class implements random horizontal flipping method + """ + + def __init__(self, opts, *args, **kwargs) -> None: + p = getattr(opts, "video_augmentation.random_horizontal_flip.p", 0.5) + super().__init__(opts=opts) + self.p = p + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--video-augmentation.random-horizontal-flip.enable", + action="store_true", + help="use {}. This flag is useful when you want to study the effect of different " + "transforms.".format(cls.__name__), + ) + group.add_argument( + "--video-augmentation.random-horizontal-flip.p", + type=float, + default=0.5, + help="Probability for random horizontal flip", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + + if random.random() <= self.p: + clip = data["image"] + _check_rgb_video_tensor(clip=clip) + clip = torch.flip(clip, dims=[-1]) + data["image"] = clip + + mask = data.get("mask", None) + if mask is not None: + mask = torch.flip(mask, dims=[-1]) + data["mask"] = mask + + return data + + +@register_transformations(name="center_crop", type="video") +class CenterCrop(BaseTransformation): + """ + This class implements center cropping method. + + .. note:: + This class assumes that the input size is greater than or equal to the desired size. + """ + + def __init__(self, opts, size: Sequence or int, *args, **kwargs) -> None: + super().__init__(opts=opts) + if isinstance(size, Sequence) and len(size) == 2: + self.height, self.width = size[0], size[1] + elif isinstance(size, Sequence) and len(size) == 1: + self.height = self.width = size[0] + elif isinstance(size, int): + self.height = self.width = size + else: + logger.error("Scale should be either an int or tuple of ints") + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--video-augmentation.center-crop.enable", + action="store_true", + help="use center cropping", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + height, width = data["image"].shape[-2:] + i = (height - self.height) // 2 + j = (width - self.width) // 2 + return _crop_fn(data=data, i=i, j=j, h=self.height, w=self.width) + + def __repr__(self) -> str: + return "{}(size=(h={}, w={}))".format( + self.__class__.__name__, self.height, self.width + ) + + +@register_transformations(name="resize", type="video") +class Resize(BaseTransformation): + """ + This class implements resizing operation. + + .. note:: + Two possible modes for resizing. + 1. Resize while maintaining aspect ratio. To enable this option, pass int as a size + 2. Resize to a fixed size. To enable this option, pass a tuple of height and width as a size + """ + + def __init__(self, opts, *args, **kwargs) -> None: + size = getattr(opts, "video_augmentation.resize.size", None) + if size is None: + logger.error("Size can not be None in {}".format(self.__class__.__name__)) + + # Possible modes. + # 1. Resize while maintaining aspect ratio. To enable this option, pass int as a size + # 2. Resize to a fixed size. To enable this option, pass a tuple of height and width as a size + + if isinstance(size, Sequence) and len(size) > 2: + logger.error( + "The length of size should be either 1 or 2 in {}".format( + self.__class__.__name__ + ) + ) + + interpolation = getattr( + opts, "video_augmentation.resize.interpolation", "bilinear" + ) + super().__init__(opts=opts) + + self.size = size + self.interpolation = _check_interpolation(interpolation) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + + group.add_argument( + "--video-augmentation.resize.enable", + action="store_true", + help="use fixed resizing", + ) + + group.add_argument( + "--video-augmentation.resize.interpolation", + type=str, + default="bilinear", + choices=SUPPORTED_PYTORCH_INTERPOLATIONS, + help="Interpolation for resizing. Default is bilinear", + ) + group.add_argument( + "--video-augmentation.resize.size", + type=int, + nargs="+", + default=None, + help="Resize video to the specified size. If int is passed, then shorter side is resized" + "to the specified size and longest side is resized while maintaining aspect ratio." + "Defaults to None.", + ) + return parser + + def __call__(self, data: Dict) -> Dict: + return _resize_fn(data=data, size=self.size, interpolation=self.interpolation) + + def __repr__(self): + return "{}(size={}, interpolation={})".format( + self.__class__.__name__, self.size, self.interpolation + ) + + +@register_transformations(name="compose", type="video") +class Compose(BaseTransformation): + """ + This method applies a list of transforms in a sequential fashion. + """ + + def __init__(self, opts, video_transforms: List, *args, **kwargs) -> None: + super().__init__(opts=opts) + self.video_transforms = video_transforms + + def __call__(self, data: Dict) -> Dict: + for t in self.video_transforms: + data = t(data) + return data + + def __repr__(self) -> str: + transform_str = ", ".join("\n\t\t\t" + str(t) for t in self.video_transforms) + repr_str = "{}({})".format(self.__class__.__name__, transform_str) + return repr_str diff --git a/Adaptive Frequency Filters/data/video_reader/__init__.py b/Adaptive Frequency Filters/data/video_reader/__init__.py new file mode 100644 index 0000000..ebc834b --- /dev/null +++ b/Adaptive Frequency Filters/data/video_reader/__init__.py @@ -0,0 +1,115 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import importlib +import argparse +from typing import Optional + +from utils.ddp_utils import is_master +from utils import logger + +from .base_video_reader import PyAVBaseReader + + +VIDEO_READER_REGISTRY = {} + + +def register_video_reader(name): + def register_video_reader_class(cls): + if name in VIDEO_READER_REGISTRY: + raise ValueError( + "Cannot register duplicate video reader class ({})".format(name) + ) + + if not issubclass(cls, PyAVBaseReader): + raise ValueError( + "Video reader ({}: {}) must extend PyAVBaseReader".format( + name, cls.__name__ + ) + ) + + VIDEO_READER_REGISTRY[name] = cls + return cls + + return register_video_reader_class + + +def supported_video_reader_str(video_reader_name): + supp_list = list(VIDEO_READER_REGISTRY.keys()) + supp_str = "Video reader ({}) is not yet supported. \n Supported video readers are:".format( + video_reader_name + ) + + for i, vr_name in enumerate(supp_list): + supp_str += "{} \t".format(vr_name) + logger.error(supp_str) + + +def general_video_reader_args(parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Video reader", description="Arguments related to video reader" + ) + group.add_argument( + "--video-reader.name", + type=str, + default="pyav_standard", + help="Name of video reader", + ) + group.add_argument( + "--video-reader.fast-video-decoding", + action="store_true", + help="Multi-threaded fast video decoding using pyav", + ) + group.add_argument( + "--video-reader.frame-stack-format", + type=str, + default="sequence_first", + choices=["sequence_first", "channel_first"], + help="Sequence first (NCHW) or channel first (CNHW) format for stacking video frames", + ) + return parser + + +def arguments_video_reader(parser: argparse.ArgumentParser): + parser = general_video_reader_args(parser=parser) + + # add video reader specific arguments + for k, v in VIDEO_READER_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + + return parser + + +def get_video_reader(opts, is_training: Optional[bool] = False, *args, **kwargs): + vr_name = getattr(opts, "video_reader.name", "pyav_standard") + + is_master_node = is_master(opts) + video_reader = None + if vr_name in VIDEO_READER_REGISTRY: + video_reader = VIDEO_READER_REGISTRY[vr_name]( + opts=opts, is_training=is_training + ) + else: + supported_video_reader_str(video_reader_name=vr_name) + + if is_master_node: + logger.log("Video reader details: ") + print("{}".format(video_reader)) + return video_reader + + +# automatically import video readers +video_reader_dir = os.path.dirname(__file__) +for file in os.listdir(video_reader_dir): + path = os.path.join(video_reader_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + vr_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("data.video_reader." + vr_name) diff --git a/Adaptive Frequency Filters/data/video_reader/base_video_reader.py b/Adaptive Frequency Filters/data/video_reader/base_video_reader.py new file mode 100644 index 0000000..e0f2a38 --- /dev/null +++ b/Adaptive Frequency Filters/data/video_reader/base_video_reader.py @@ -0,0 +1,234 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import copy +from typing import Optional, Any, List +import torch +import numpy as np +import av +from torch import Tensor +import random +import argparse +from PIL import Image +from torchvision.transforms import functional as F_vision + +from utils import logger +from ..transforms import image_pil as T + + +class PyAVBaseReader(object): + """ + PyAv video reader + + Args: + opts: command line arguments + is_training (Optional[bool]): Training or validation mode. Default: `False` + """ + + def __init__(self, opts, is_training: Optional[bool] = False, *args, **kwargs): + super().__init__() + self.fast_decoding = getattr(opts, "video_reader.fast_video_decoding", False) + self.frame_stack_format = getattr( + opts, "video_reader.frame_stack_format", "sequence_first" + ) + self.stack_frame_dim = 1 if self.frame_stack_format == "channel_first" else 0 + + self.frame_transforms = ( + self._frame_transform(opts=opts) if is_training else None + ) + self.random_erase_transform = ( + self._random_erase_transform(opts=opts) if is_training else None + ) + + self.frame_transforms_str = "" + if self.frame_transforms is not None: + self.frame_transforms_str += "\t {}".format( + self.frame_transforms.__repr__() + ) + if self.random_erase_transform is not None: + self.frame_transforms_str += "\t {}".format( + self.random_erase_transform.__repr__() + ) + + self.num_frames_cache = dict() + + @staticmethod + def _frame_transform(opts): + auto_augment = getattr(opts, "image_augmentation.auto_augment.enable", False) + rand_augment = getattr(opts, "image_augmentation.rand_augment.enable", False) + + if auto_augment and rand_augment: + logger.warning( + "AutoAugment and RandAugment are mutually exclusive. Use either of them, but not both" + ) + elif auto_augment: + return T.AutoAugment(opts=opts) + elif rand_augment: + return T.RandAugment(opts=opts) + return None + + @staticmethod + def _random_erase_transform(opts): + random_erase = getattr(opts, " image_augmentation.random_erase.enable", False) + if random_erase: + return T.RandomErasing(opts=opts) + return None + + def __repr__(self): + return "{}(\n\tfast_decoding={}\n\tframe_stack_format={}\n)".format( + self.__class__.__name__, self.fast_decoding, self.frame_stack_format + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + def check_video(self, filename: str) -> bool: + try: + # Adapted from basic demo: https://pyav.org/docs/stable/#basic-demo + with av.open(filename) as container: + # Decode the first video channel. + for frame in container.decode(video=0): + frame_idx = frame.index + break + return True + except Exception as e: + return False + + def read_video(self, filename: str, *args, **kwargs) -> Any: + raise NotImplementedError + + def num_frames(self, filename: str) -> int: + if filename in self.num_frames_cache: + return self.num_frames_cache[filename] + else: + total_frames = 0 + with av.open(filename) as container: + total_frames = container.streams.video[0].frames + self.num_frames_cache[filename] = total_frames + return total_frames + + def frame_to_tensor(self, frame): + frame_np = frame.to_ndarray(format="rgb24") + if self.frame_transforms is not None: + # + frame_pil = Image.fromarray(frame_np) + frame_pil = self.frame_transforms({"image": frame_pil})["image"] + frame_np = np.array(frame_pil) + + frame_np = frame_np.transpose(2, 0, 1) + frame_np = np.ascontiguousarray(frame_np) + # [C, H, W] + frame_torch = torch.from_numpy(frame_np) + + # normalize the frame between 0 and 1 + frame_torch = frame_torch.div(255.0) + + # apply random erase transform + if self.random_erase_transform is not None: + frame_torch = self.random_erase_transform({"image": frame_torch})["image"] + + return frame_torch + + @staticmethod + def random_sampling( + desired_frames: int, total_frames: int, n_clips: int, *args, **kwargs + ) -> List: + # divide the video into K clips + try: + interval = ( + desired_frames if total_frames >= desired_frames * (n_clips + 1) else 1 + ) + # The range of start Id is between [0, total_frames - n_desired_frames] + temp = max(0, min(total_frames - desired_frames, total_frames)) + start_ids = sorted( + random.sample(population=range(0, temp, interval), k=n_clips) + ) + # 30 frames and 120 frames in 1s and 4s videos @ 30 FPS, respectively + # The end_id is randomly selected between start_id + 30 and start_id + 120 + end_ids = [ + min( + max(s_id + random.randint(30, 120), s_id + desired_frames), + total_frames - 1, + ) + for s_id in start_ids + ] + except: + # fall back to uniform + video_clip_ids = np.linspace( + 0, total_frames - 1, n_clips + 1, dtype=int + ).tolist() + + start_ids = video_clip_ids[:-1] + end_ids = video_clip_ids[1:] + + frame_ids = [] + for start_idx, end_idx in zip(start_ids, end_ids): + try: + clip_frame_ids = sorted( + random.sample( + population=range(start_idx, end_idx), k=desired_frames + ) + ) + except: + # sample with repetition + clip_frame_ids = np.linspace( + start=start_idx, stop=end_idx - 1, num=desired_frames, dtype=int + ).tolist() + frame_ids.extend(clip_frame_ids) + return frame_ids + + @staticmethod + def uniform_sampling( + desired_frames: int, total_frames: int, n_clips: int, *args, **kwargs + ): + video_clip_ids = np.linspace( + 0, total_frames - 1, n_clips + 1, dtype=int + ).tolist() + start_ids = video_clip_ids[:-1] + end_ids = video_clip_ids[1:] + + frame_ids = [] + for start_idx, end_idx in zip(start_ids, end_ids): + clip_frame_ids = np.linspace( + start=start_idx, stop=end_idx - 1, num=desired_frames, dtype=int + ).tolist() + frame_ids.extend(clip_frame_ids) + return frame_ids + + def convert_to_clips(self, video: torch.Tensor, n_clips: int): + # video is [N, C, H, W] or [C, N, H, W] + video_clips = torch.chunk(video, chunks=n_clips, dim=self.stack_frame_dim) + video_clips = torch.stack(video_clips, dim=0) + # video_clips is [T, n, C, H, W] or [T, C, n, H, W] + return video_clips + + def process_video( + self, + vid_filename: str, + n_frames_per_clip: Optional[int] = -1, + clips_per_video: Optional[int] = 1, + video_transform_fn: Optional = None, + is_training: Optional[bool] = False, + ): + raise NotImplementedError + + def dummy_video( + self, clips_per_video: int, n_frames_to_sample: int, height: int, width: int + ): + + # [K, C, N, H, W] or # [K, N, C, H, W] + # K --> number of clips, C --> Image channels, N --> Number of frames per clip, H --> Height, W --> Width + tensor_size = ( + (clips_per_video, 3, n_frames_to_sample, height, width) + if self.frame_stack_format == "channel_first" + else (clips_per_video, n_frames_to_sample, 3, height, width) + ) + + input_video = torch.zeros( + size=tensor_size, dtype=torch.float32, device=torch.device("cpu") + ) + return input_video diff --git a/Adaptive Frequency Filters/data/video_reader/default_video_reader.py b/Adaptive Frequency Filters/data/video_reader/default_video_reader.py new file mode 100644 index 0000000..e97fcd0 --- /dev/null +++ b/Adaptive Frequency Filters/data/video_reader/default_video_reader.py @@ -0,0 +1,132 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import copy +from typing import Optional, List +import torch +import av +from torch import Tensor + +from . import PyAVBaseReader, register_video_reader + + +@register_video_reader(name="pyav_default") +class PyAVDefaultReader(PyAVBaseReader): + """ + Default PyAv video reader + + Args: + opts: command line arguments + is_training (Optional[bool]): Training or validation mode. Default: `False` + """ + + def __init__( + self, opts, is_training: Optional[bool] = False, *args, **kwargs + ) -> None: + super().__init__(opts=opts, is_training=is_training, *args, **kwargs) + + def read_video( + self, filename: str, frame_ids: Optional[List] = None, *args, **kwargs + ) -> Optional[Tensor]: + + try: + if frame_ids is None: + return None + + # Check basic demo for Pyav usage: https://pyav.org/docs/stable/#basic-demo + with av.open(filename) as container: + stream = container.streams.video[0] + + if self.fast_decoding: + stream.thread_type = "AUTO" + + video_frames = [] + for frame in container.decode(video=0): + frame_idx = frame.index + if frame_idx in frame_ids: + # using PIL so that we can apply wide range of augmentations on Frame, such as RandAug + frame_torch = self.frame_to_tensor(frame) + # check for duplicate frame ids + n_duplicate_frames = max(1, frame_ids.count(frame_idx)) + + video_frames.extend( + [copy.deepcopy(frame_torch)] * n_duplicate_frames + ) + + # [C, H, W] x N --> [N, C, H, W] or [C, N, H, W] + if len(video_frames) == len(frame_ids): + video_frames = torch.stack(video_frames, dim=self.stack_frame_dim) + return video_frames + elif 0 < len(video_frames) < len(frame_ids): + n_delta_frames = len(frame_ids) - len(video_frames) + # add black frames + delta_frame = [torch.zeros_like(video_frames[-1])] * n_delta_frames + video_frames.extend(delta_frame) + + video_frames = torch.stack(video_frames, dim=self.stack_frame_dim) + return video_frames + else: + return None + except av.AVError as ave_error: + return None + + def process_video( + self, + vid_filename: str, + n_frames_per_clip: Optional[int] = -1, + clips_per_video: Optional[int] = 1, + video_transform_fn: Optional = None, + is_training: Optional[bool] = False, + *args, + **kwargs + ): + sampling_method = self.random_sampling if is_training else self.uniform_sampling + + total_frames = self.num_frames(filename=vid_filename) + total_frames_to_sample = n_frames_per_clip * clips_per_video + if n_frames_per_clip < 1: + n_frames_per_clip = total_frames + total_frames_to_sample = total_frames + + frame_ids = sampling_method( + desired_frames=n_frames_per_clip, + total_frames=total_frames, + n_clips=clips_per_video, + ) + + # [N, C, H, W] or [C, N, H, W] + torch_video = self.read_video(filename=vid_filename, frame_ids=frame_ids) + + if isinstance(torch_video, torch.Tensor): + + if video_transform_fn is not None: + # Apply transformation + torch_video = video_transform_fn({"image": torch_video}) + torch_video = torch_video["image"] + + if torch_video.shape[self.stack_frame_dim] < total_frames_to_sample: + # This is very unlikely but, for a safer side. + delta = total_frames_to_sample - torch_video.shape[self.stack_frame_dim] + clip_height, clip_width = torch_video.shape[-2:] + if self.stack_frame_dim == 0: + delta_frames = torch.zeros(size=(delta, 3, clip_height, clip_width)) + else: + delta_frames = torch.zeros(size=(3, delta, clip_height, clip_width)) + torch_video = torch.cat( + [torch_video, delta_frames], dim=self.stack_frame_dim + ) + elif torch_video.shape[self.stack_frame_dim] > total_frames_to_sample: + # truncate + if self.stack_frame_dim == 0: + torch_video = torch_video[:total_frames_to_sample, ...] + else: + torch_video = torch_video[:, :total_frames_to_sample, ...] + + assert torch_video.shape[self.stack_frame_dim] % clips_per_video == 0 + + return self.convert_to_clips(video=torch_video, n_clips=clips_per_video) + else: + return None diff --git a/Adaptive Frequency Filters/data/video_reader/key_frame_reader.py b/Adaptive Frequency Filters/data/video_reader/key_frame_reader.py new file mode 100644 index 0000000..4eae68a --- /dev/null +++ b/Adaptive Frequency Filters/data/video_reader/key_frame_reader.py @@ -0,0 +1,116 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from typing import Optional, List +import torch +import numpy as np +import av +from torch import Tensor + +from . import PyAVBaseReader, register_video_reader + + +@register_video_reader(name="pyav_key_frame_only") +class PyAVKeyFrameReader(PyAVBaseReader): + """ + PyAv video reader for reading key frames only + + Args: + opts: command line arguments + is_training (Optional[bool]): Training or validation mode. Default: `False` + """ + + def __init__(self, opts, is_training: Optional[bool] = False, *args, **kwargs): + super().__init__(opts=opts, is_training=is_training) + + def read_video( + self, filename: str, frame_indices: Optional[List] = None, *args, **kwargs + ) -> Optional[Tensor]: + # for key frames, we do not know the indices of key frames. + # so we can't use frame indices here + try: + # Check basic demo for Pyav usage: https://pyav.org/docs/stable/#basic-demo + with av.open(filename) as container: + stream = container.streams.video[0] + stream.codec_context.skip_frame = "NONKEY" + + if self.fast_decoding: + stream.thread_type = "AUTO" + + key_frames = [] + for frame in container.decode(video=0): + frame = self.frame_to_tensor(frame) + key_frames.append(frame) + + # [C, H, W] x N --> [N, C, H, W] or [C, N, H, W] + key_frames = torch.stack(key_frames, dim=self.stack_frame_dim) + return key_frames + except av.AVError as ave_error: + return None + + def process_video( + self, + vid_filename: str, + n_frames_per_clip: Optional[int] = -1, + clips_per_video: Optional[int] = 1, + video_transform_fn: Optional = None, + is_training: Optional[bool] = False, + *args, + **kwargs + ): + # [N, C, H, W] or [C, N, H, W] + torch_video = self.read_video(filename=vid_filename) + + if isinstance(torch_video, torch.Tensor): + + if video_transform_fn is not None: + # Apply transformation + torch_video = video_transform_fn({"image": torch_video}) + torch_video = torch_video["image"] + + if n_frames_per_clip == -1: + return self.convert_to_clips(video=torch_video, n_clips=clips_per_video) + + # select frames + total_frames = torch_video.shape[self.stack_frame_dim] + total_desired_frames = clips_per_video * n_frames_per_clip + + if is_training: + frame_ids = self.random_sampling( + desired_frames=total_desired_frames, total_frames=total_frames + ) + else: + frame_ids = self.uniform_sampling( + desired_frames=total_desired_frames, total_frames=total_frames + ) + + # [N, C, H, W] or [C, N, H, W] + torch_video = torch.index_select( + torch_video, dim=self.stack_frame_dim, index=frame_ids + ) + + if torch_video.shape[self.stack_frame_dim] < total_desired_frames: + # This is very unlikely but, for a safer side. + delta = total_desired_frames - torch_video.shape[self.stack_frame_dim] + clip_height, clip_width = torch_video.shape[-2:] + if self.stack_frame_dim == 0: + delta_frames = torch.zeros(size=(delta, 3, clip_height, clip_width)) + else: + delta_frames = torch.zeros(size=(3, delta, clip_height, clip_width)) + torch_video = torch.cat( + [torch_video, delta_frames], dim=self.stack_frame_dim + ) + elif torch_video.shape[self.stack_frame_dim] > total_desired_frames: + if self.stack_frame_dim == 0: + torch_video = torch_video[:total_desired_frames, ...] + else: + torch_video = torch_video[:, :total_desired_frames, ...] + + assert torch_video.shape[self.stack_frame_dim] % clips_per_video == 0 + + return self.convert_to_clips(video=torch_video, n_clips=clips_per_video) + else: + return None diff --git a/Adaptive Frequency Filters/engine/__init__.py b/Adaptive Frequency Filters/engine/__init__.py new file mode 100644 index 0000000..59a32e6 --- /dev/null +++ b/Adaptive Frequency Filters/engine/__init__.py @@ -0,0 +1,8 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from .training_engine import Trainer +from .evaluation_engine import Evaluator diff --git a/Adaptive Frequency Filters/engine/detection_utils/__init__.py b/Adaptive Frequency Filters/engine/detection_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/engine/detection_utils/coco_map.py b/Adaptive Frequency Filters/engine/detection_utils/coco_map.py new file mode 100644 index 0000000..5e1af79 --- /dev/null +++ b/Adaptive Frequency Filters/engine/detection_utils/coco_map.py @@ -0,0 +1,110 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import numpy as np +from pycocotools.cocoeval import COCOeval +from pycocotools.coco import COCO +from pycocotools import mask as maskUtils +from typing import Optional, List +from contextlib import redirect_stdout +import io + +from utils import logger + + +def coco_evaluation( + opts, + predictions: List[np.ndarray], + split: Optional[str] = "val", + year: Optional[int] = 2017, + *args, + **kwargs +) -> None: + root = getattr(opts, "dataset.root_val", None) + ann_file = os.path.join(root, "annotations/instances_{}{}.json".format(split, year)) + bkrnd_id = 0 if getattr(opts, "dataset.detection.no_background_id", False) else 1 + coco = COCO(ann_file) + coco_categories = sorted(coco.getCatIds()) + + coco_id_to_contiguous_id = { + coco_id: i + bkrnd_id for i, coco_id in enumerate(coco_categories) + } + contiguous_id_to_coco_id = {v: k for k, v in coco_id_to_contiguous_id.items()} + + coco_results = {"bbox": []} + + for i, (image_id, boxes, labels, scores, masks) in enumerate(predictions): + if labels.shape[0] == 0: + continue + + boxes = boxes.tolist() + labels = labels.tolist() + scores = scores.tolist() + + coco_results["bbox"].extend( + [ + { + "image_id": image_id, + "category_id": contiguous_id_to_coco_id[labels[k]], + "bbox": [ + box[0], + box[1], + box[2] - box[0], + box[3] - box[1], + ], # to xywh format + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + + if masks is not None: + if "segm" not in coco_results: + coco_results["segm"] = [] + + # Masks are in [N, H, W] format + rles = [ + maskUtils.encode( + np.array(mask[:, :, np.newaxis], dtype=np.uint8, order="F") + )[0] + for mask in masks + ] + + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + coco_results["segm"].extend( + [ + { + "image_id": image_id, + "category_id": contiguous_id_to_coco_id[labels[seg_id]], + "segmentation": rle, + "score": scores[seg_id], + } + for seg_id, rle in enumerate(rles) + ] + ) + + if len(coco_results) == 0: + logger.error("Cannot compute COCO stats. Please check the predictions") + + for iou_type, coco_result in coco_results.items(): + with redirect_stdout(io.StringIO()): + coco_dt = COCO.loadRes(coco, coco_result) + + # Run COCO evaluation + coco_eval = COCOeval(coco, coco_dt, iou_type) + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + +def compute_quant_scores(opts, predictions: List, *args, **kwargs) -> None: + dataset_name = getattr(opts, "dataset.name", None) + if dataset_name.find("coco") > -1: + coco_evaluation(opts=opts, predictions=predictions) + else: + raise NotImplementedError diff --git a/Adaptive Frequency Filters/engine/eval_detection.py b/Adaptive Frequency Filters/engine/eval_detection.py new file mode 100644 index 0000000..37b26a0 --- /dev/null +++ b/Adaptive Frequency Filters/engine/eval_detection.py @@ -0,0 +1,410 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os.path +import numpy as np +import torch +import multiprocessing +from torch.nn import functional as F +from tqdm import tqdm +import glob +from typing import Optional, Dict +from torch import Tensor, nn +from torchvision.transforms import functional as F_vision +from PIL import Image + +from common import SUPPORTED_IMAGE_EXTNS +from options.opts import get_detection_eval_arguments +from affnet import get_model +from affnet.models.detection.ssd import DetectionPredTuple +from data import create_eval_loader +from data.datasets.detection.coco_base import COCODetection +from utils.tensor_utils import to_numpy, image_size_from_opts +from utils.common_utils import device_setup, create_directories +from utils.ddp_utils import is_master +from utils import logger +from engine.utils import print_summary +from engine.detection_utils.coco_map import compute_quant_scores +from utils.visualization_utils import draw_bounding_boxes +from utils.download_utils import get_local_path + +from .utils import autocast_fn, get_batch_size + +# Evaluation on MSCOCO detection task +object_names = COCODetection.class_names() + + +def predict_and_save( + opts, + input_tensor: Tensor, + model: nn.Module, + input_np: Optional[np.ndarray] = None, + device: Optional = torch.device("cpu"), + is_coco_evaluation: Optional[bool] = False, + file_name: Optional[str] = None, + output_stride: Optional[int] = 32, + orig_h: Optional[int] = None, + orig_w: Optional[int] = None, + *args, + **kwargs +): + """ + This function makes a prediction on the input tensor and optionally save the detection results + Args: + opts: command-line arguments + input_tensor (Tensor): Input tensor of size :math:`(1, C, H, W)` + model (nn.Module): detection model + input_np (Optional[np.ndarray]): Input numpy image of size :math:`(H, W, C)`. Used only for visualization purposes. Defaults to None + device (Optional[str]): Device. Defaults to cpu. + is_coco_evaluation (Optional[bool]): Evaluating on MS-COCO object detection. Defaults to False. + file_name (Optional[bool]): File name for storing detection results. Only applicable when `is_coco_evaluation` is False. Defaults to None. + output_stride (Optional[int]): Output stride. This is used to ensure that image size is divisible by this factor. Defaults to 32. + orig_h (Optional[int]): Original height of the input image. Useful for visualizing detection results. Defaults to None. + orig_w (Optional[int]): Original width of the input image. Useful for visualizing detection results. Defaults to None. + """ + mixed_precision_training = getattr(opts, "common.mixed_precision", False) + mixed_precision_dtype = getattr(opts, "common.mixed_precision_dtype", "float16") + + if input_np is None and not is_coco_evaluation: + input_np = to_numpy(input_tensor).squeeze( # convert to numpy + 0 + ) # remove batch dimension + + curr_height, curr_width = input_tensor.shape[2:] + + # check if dimensions are multiple of output_stride, otherwise, we get dimension mismatch errors. + # if not, then resize them + new_h = (curr_height // output_stride) * output_stride + new_w = (curr_width // output_stride) * output_stride + + if new_h != curr_height or new_w != curr_width: + # resize the input image, so that we do not get dimension mismatch errors in the forward pass + input_tensor = F.interpolate( + input=input_tensor, + size=(new_h, new_w), + mode="bilinear", + align_corners=False, + ) + + # move data to device + input_tensor = input_tensor.to(device) + + with autocast_fn( + enabled=mixed_precision_training, amp_precision=mixed_precision_dtype + ): + # prediction + # We dot scale inside the prediction function because we resize the input tensor such + # that the dimensions are divisible by output stride. + prediction: DetectionPredTuple = model.predict(input_tensor, is_scaling=False) + + if orig_w is None: + assert orig_h is None + orig_h, orig_w = input_np.shape[:2] + elif orig_h is None: + assert orig_w is None + orig_h, orig_w = input_np.shape[:2] + assert orig_h is not None and orig_w is not None + + # convert tensors to numpy + boxes = prediction.boxes.cpu().numpy() + labels = prediction.labels.cpu().numpy() + scores = prediction.scores.cpu().numpy() + + masks = prediction.masks + + # Ensure that there is at least one mask + if masks is not None and masks.shape[0] > 0: + # masks are in [N, H, W] format + # for interpolation, add a dummy batch dimension + masks = F.interpolate( + masks.unsqueeze(0), + size=(orig_h, orig_w), + mode="bilinear", + align_corners=True, + ).squeeze(0) + # convert to binary masks + masks = masks > 0.5 + masks = masks.cpu().numpy() + + boxes[..., 0::2] = np.clip(a_min=0, a_max=orig_w, a=boxes[..., 0::2] * orig_w) + boxes[..., 1::2] = np.clip(a_min=0, a_max=orig_h, a=boxes[..., 1::2] * orig_h) + + if is_coco_evaluation: + return boxes, labels, scores, masks + + detection_res_file_name = None + if file_name is not None: + file_name = file_name.split(os.sep)[-1].split(".")[0] + ".jpg" + res_dir = "{}/detection_results".format(getattr(opts, "common.exp_loc", None)) + if not os.path.isdir(res_dir): + os.makedirs(res_dir, exist_ok=True) + detection_res_file_name = "{}/{}".format(res_dir, file_name) + + draw_bounding_boxes( + image=input_np, + boxes=boxes, + labels=labels, + scores=scores, + masks=masks, + # some models may not use background class which is present in class names. + # adjust the class names + object_names=object_names[-model.n_detection_classes :] + if hasattr(model, "n_detection_classes") + else object_names, + is_bgr_format=True, + save_path=detection_res_file_name, + ) + + +def predict_labeled_dataset(opts, **kwargs): + device = getattr(opts, "dev.device", torch.device("cpu")) + + # set-up data loaders + val_loader = create_eval_loader(opts) + + # set-up the model + model = get_model(opts) + model.eval() + model = model.to(device=device) + print_summary(opts=opts, model=model) + + if model.training: + logger.warning("Model is in training mode. Switching to evaluation mode") + model.eval() + + with torch.no_grad(): + predictions = [] + for img_idx, batch in tqdm(enumerate(val_loader)): + samples, targets = batch["samples"], batch["targets"] + + batch_size = get_batch_size(samples) + if isinstance(samples, Dict): + assert "image" in samples, "samples does not contain image key" + input_tensor = samples["image"] + else: + input_tensor = samples + + assert ( + batch_size == 1 + ), "We recommend to run detection evaluation with a batch size of 1" + + orig_w = targets["image_width"].item() + orig_h = targets["image_height"].item() + image_id = targets["image_id"].item() + + boxes, labels, scores, masks = predict_and_save( + opts=opts, + input_tensor=input_tensor, + model=model, + device=device, + is_coco_evaluation=True, + orig_w=orig_w, + orig_h=orig_h, + ) + + predictions.append([image_id, boxes, labels, scores, masks]) + + compute_quant_scores(opts=opts, predictions=predictions) + + +def read_and_process_image(opts, image_fname: str, *args, **kwargs): + input_img = Image.open(image_fname).convert("RGB") + input_np = np.array(input_img) + orig_w, orig_h = input_img.size + + # Resize the image to the resolution that detector supports + res_h, res_w = image_size_from_opts(opts) + input_img = F_vision.resize( + input_img, + size=[res_h, res_w], + interpolation=F_vision.InterpolationMode.BILINEAR, + ) + input_tensor = F_vision.pil_to_tensor(input_img) + input_tensor = input_tensor.float().div(255.0).unsqueeze(0) + return input_tensor, input_np, orig_h, orig_w + + +def predict_image(opts, image_fname, **kwargs): + image_fname = get_local_path(opts, image_fname) + if not os.path.isfile(image_fname): + logger.error("Image file does not exist at: {}".format(image_fname)) + + input_tensor, input_imp_copy, orig_h, orig_w = read_and_process_image( + opts, image_fname=image_fname + ) + + image_fname = image_fname.split(os.sep)[-1] + + device = getattr(opts, "dev.device", torch.device("cpu")) + # set-up the model + model = get_model(opts) + model.eval() + model = model.to(device=device) + print_summary(opts=opts, model=model) + + if model.training: + logger.warning("Model is in training mode. Switching to evaluation mode") + model.eval() + + with torch.no_grad(): + predict_and_save( + opts=opts, + input_tensor=input_tensor, + input_np=input_imp_copy, + file_name=image_fname, + model=model, + device=device, + orig_h=orig_h, + orig_w=orig_w, + ) + + +def predict_images_in_folder(opts, **kwargs): + img_folder_path = getattr(opts, "evaluation.detection.path", None) + if img_folder_path is None: + logger.error( + "Image folder is not passed. Please use --evaluation.detection.path as an argument to pass the location of image folder".format( + img_folder_path + ) + ) + elif not os.path.isdir(img_folder_path): + logger.error( + "Image folder does not exist at: {}. Please check".format(img_folder_path) + ) + + img_files = [] + for e in SUPPORTED_IMAGE_EXTNS: + img_files_with_extn = glob.glob("{}/*{}".format(img_folder_path, e)) + if len(img_files_with_extn) > 0 and isinstance(img_files_with_extn, list): + img_files.extend(img_files_with_extn) + + if len(img_files) == 0: + logger.error( + "Number of image files found at {}: {}".format( + img_folder_path, len(img_files) + ) + ) + + logger.log( + "Number of image files found at {}: {}".format(img_folder_path, len(img_files)) + ) + + device = getattr(opts, "dev.device", torch.device("cpu")) + # set-up the model + model = get_model(opts) + model.eval() + model = model.to(device=device) + print_summary(opts=opts, model=model) + + if model.training: + logger.warning("Model is in training mode. Switching to evaluation mode") + model.eval() + + with torch.no_grad(): + for img_idx, image_fname in enumerate(img_files): + input_tensor, input_np, orig_h, orig_w = read_and_process_image( + opts=opts, image_fname=image_fname + ) + + image_fname = image_fname.split(os.sep)[-1] + + predict_and_save( + opts=opts, + input_tensor=input_tensor, + input_np=input_np, + file_name=image_fname, + model=model, + device=device, + orig_h=orig_h, + orig_w=orig_w, + ) + + +def main_detection_evaluation(**kwargs): + opts = get_detection_eval_arguments() + + dataset_name = getattr(opts, "dataset.name", "imagenet") + if dataset_name.find("coco") > -1: + # replace model specific datasets (e.g., coco_ssd) with general COCO dataset + setattr(opts, "dataset.name", "coco") + + # device set-up + opts = device_setup(opts) + + node_rank = getattr(opts, "ddp.rank", 0) + if node_rank < 0: + logger.error("--rank should be >=0. Got {}".format(node_rank)) + + is_master_node = is_master(opts) + + # create the directory for saving results + save_dir = getattr(opts, "common.results_loc", "results") + run_label = getattr(opts, "common.run_label", "run_1") + exp_dir = "{}/{}".format(save_dir, run_label) + setattr(opts, "common.exp_loc", exp_dir) + logger.log("Results (if any) will be stored here: {}".format(exp_dir)) + + create_directories(dir_path=exp_dir, is_master_node=is_master_node) + + num_gpus = getattr(opts, "dev.num_gpus", 1) + if num_gpus < 2: + cls_norm_type = getattr(opts, "model.normalization.name", "batch_norm_2d") + if cls_norm_type.find("sync") > -1: + # replace sync_batch_norm with standard batch norm on PU + setattr( + opts, "model.normalization.name", cls_norm_type.replace("sync_", "") + ) + setattr( + opts, + "model.classification.normalization.name", + cls_norm_type.replace("sync_", ""), + ) + + # we disable the DDP setting for evaluation tasks + setattr(opts, "ddp.use_distributed", False) + + # No of data workers = no of CPUs (if not specified or -1) + n_cpus = multiprocessing.cpu_count() + dataset_workers = getattr(opts, "dataset.workers", -1) + + if dataset_workers == -1: + setattr(opts, "dataset.workers", n_cpus) + + # We are not performing any operation like resizing and cropping on images + # Because image dimensions are different, we process 1 sample at a time. + setattr(opts, "dataset.train_batch_size0", 1) + setattr(opts, "dataset.val_batch_size0", 1) + setattr(opts, "dev.device_id", None) + + eval_mode = getattr(opts, "evaluation.detection.mode", None) + + if eval_mode == "single_image": + num_classes = getattr(opts, "model.detection.n_classes", 81) + assert num_classes is not None + + # test a single image + img_f_name = getattr(opts, "evaluation.detection.path", None) + predict_image(opts, img_f_name, **kwargs) + elif eval_mode == "image_folder": + num_seg_classes = getattr(opts, "model.detection.n_classes", 81) + assert num_seg_classes is not None + + # test all images in a folder + predict_images_in_folder(opts=opts, **kwargs) + elif eval_mode == "validation_set": + # evaluate and compute stats for labeled image dataset + # This is useful for generating results for validation set and compute quantitative results + predict_labeled_dataset(opts=opts, **kwargs) + else: + logger.error( + "Supported modes are single_image, image_folder, and validation_set. Got: {}".format( + eval_mode + ) + ) + + +if __name__ == "__main__": + main_detection_evaluation() diff --git a/Adaptive Frequency Filters/engine/eval_segmentation.py b/Adaptive Frequency Filters/engine/eval_segmentation.py new file mode 100644 index 0000000..5b9bcf6 --- /dev/null +++ b/Adaptive Frequency Filters/engine/eval_segmentation.py @@ -0,0 +1,500 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- +import copy +import torch +import multiprocessing +import os +from tqdm import tqdm +import glob +from typing import Optional, Tuple, List +from torch import Tensor, nn +from torch.nn import functional as F +from torchvision.transforms import functional as F_vision +from PIL import Image + +from utils import logger +from utils.tensor_utils import image_size_from_opts +from options.opts import get_segmentation_eval_arguments +from utils.common_utils import device_setup, create_directories +from utils.ddp_utils import is_master +from affnet import get_model +from data import create_eval_loader +from utils.color_map import Colormap +from engine.utils import print_summary +from common import SUPPORTED_IMAGE_EXTNS +from metrics.confusion_mat import ConfusionMatrix +from utils.visualization_utils import convert_to_cityscape_format +from utils.download_utils import get_local_path + +from .utils import autocast_fn + +""" +Notes: + +1) We have separate scripts for evaluating segmentation models because the size of input images varies and +we do not want to apply any resizing operations to input because that distorts the quality and hurts the performance. + +2) [Optional] We want to save the outputs in the same size as that of the input image. +""" + + +def predict_and_save( + opts, + input_tensor: Tensor, + file_name: str, + orig_h: int, + orig_w: int, + model: nn.Module, + target_mask: Optional[Tensor] = None, + device: Optional = torch.device("cpu"), + conf_mat: Optional[ConfusionMatrix] = None, + color_map: List = None, + orig_image: Optional[Image.Image] = None, + adjust_label: Optional[int] = 0, + is_cityscape: Optional[bool] = False, + *args, + **kwargs +) -> None: + """Predict the segmentation mask and optionally save them""" + + mixed_precision_training = getattr(opts, "common.mixed_precision", False) + mixed_precision_dtype = getattr(opts, "common.mixed_precision_dtype", "float16") + + output_stride = getattr(opts, "model.segmentation.output_stride", 16) + if output_stride == 1: + # we set it to 32 because most of the ImageNet models have 5 downsampling stages (2^5 = 32) + output_stride = 32 + + if orig_image is None: + orig_image = F_vision.to_pil_image(input_tensor[0]) + + curr_h, curr_w = input_tensor.shape[2:] + + # check if dimensions are multiple of output_stride, otherwise, we get dimension mismatch errors. + # if not, then resize them + new_h = (curr_h // output_stride) * output_stride + new_w = (curr_w // output_stride) * output_stride + + if new_h != curr_h or new_w != curr_w: + # resize the input image, so that we do not get dimension mismatch errors in the forward pass + input_tensor = F.interpolate( + input=input_tensor, size=(new_h, new_w), mode="bilinear", align_corners=True + ) + + file_name = file_name.split(os.sep)[-1].split(".")[0] + ".png" + + # move data to device + input_tensor = input_tensor.to(device) + if target_mask is not None: + target_mask = target_mask.to(device) + + with autocast_fn( + enabled=mixed_precision_training, amp_precision=mixed_precision_dtype + ): + # prediction + pred = model(input_tensor, orig_size=(orig_h, orig_w)) + + if isinstance(pred, Tuple) and len(pred) == 2: + # when segmentation mask from decoder and auxiliary decoder are returned + pred = pred[0] + elif isinstance(pred, Tensor): + pred = pred + else: + raise NotImplementedError( + "Predicted must should be either an instance of Tensor or Tuple[Tensor, Tensor]" + ) + + num_classes = pred.shape[1] + pred_mask = pred.argmax(1).squeeze(0) + + if target_mask is not None and conf_mat is not None: + conf_mat.update( + ground_truth=target_mask.flatten(), + prediction=pred_mask.flatten(), + n_classes=num_classes, + ) + + save_dir = getattr(opts, "common.exp_loc", None) + pred_mask = pred_mask + adjust_label + if target_mask is not None: + target_mask = target_mask + adjust_label + + # Visualize results + if getattr(opts, "evaluation.segmentation.apply_color_map", False): + # For some dataset, we need to adjust the labels. For example, we need adjust by 1 for ADE20k + + draw_colored_masks( + opts=opts, + orig_image=orig_image, + pred_mask=pred_mask, + target_mask=target_mask, + results_location=save_dir, + color_map=color_map, + file_name=file_name, + ) + + if getattr(opts, "evaluation.segmentation.save_masks", False): + draw_binary_masks( + opts=opts, + pred_mask=pred_mask, + file_name=file_name, + is_cityscape=is_cityscape, + results_location=save_dir, + ) + + +def draw_binary_masks( + opts, + pred_mask: Tensor, + file_name: str, + results_location: str, + is_cityscape: Optional[bool] = False, +) -> None: + """Save masks whose values ranges between 0 and number_of_classes - 1""" + no_color_mask_dir = "{}/predictions_no_cmap".format(results_location) + if not os.path.isdir(no_color_mask_dir): + os.makedirs(no_color_mask_dir, exist_ok=True) + no_color_mask_f_name = "{}/{}".format(no_color_mask_dir, file_name) + + if is_cityscape: + # convert mask values to cityscapes format + pred_mask = convert_to_cityscape_format(img=pred_mask) + pred_mask_pil = F_vision.to_pil_image(pred_mask.byte()) + pred_mask_pil.save(no_color_mask_f_name) + + +def draw_colored_masks( + opts, + orig_image: Image.Image, + pred_mask: Tensor, + target_mask: Tensor, + file_name: str, + results_location: str, + color_map: Optional[List] = None, +) -> None: + """Apply color map to segmentation masks""" + + alpha = getattr(opts, "evaluation.segmentation.overlay_mask_weight", 0.5) + save_overlay_rgb_pred = getattr( + opts, "evaluation.segmentation.save_overlay_rgb_pred", False + ) + + if color_map is None: + color_map = Colormap().get_color_map_list() + + # convert predicted tensor to PIL images, apply color map and save + pred_mask_pil = F_vision.to_pil_image(pred_mask.byte()) + pred_mask_pil.putpalette(color_map) + pred_mask_pil = pred_mask_pil.convert("RGB") + pred_color_mask_dir = "{}/predictions_cmap".format(results_location) + if not os.path.isdir(pred_color_mask_dir): + os.makedirs(pred_color_mask_dir, exist_ok=True) + color_mask_f_name = "{}/{}".format(pred_color_mask_dir, file_name) + pred_mask_pil.save(color_mask_f_name) + logger.log("Predicted mask is saved at: {}".format(color_mask_f_name)) + + if target_mask is not None: + # convert target tensor to PIL images, apply colormap, and save + target_mask_pil = F_vision.to_pil_image(target_mask.byte()) + target_mask_pil.putpalette(color_map) + target_mask_pil = target_mask_pil.convert("RGB") + target_color_mask_dir = "{}/gt_cmap".format(results_location) + if not os.path.isdir(target_color_mask_dir): + os.makedirs(target_color_mask_dir, exist_ok=True) + gt_color_mask_f_name = "{}/{}".format(target_color_mask_dir, file_name) + target_mask_pil.save(gt_color_mask_f_name) + logger.log("Target mask is saved at: {}".format(color_mask_f_name)) + + if save_overlay_rgb_pred and orig_image is not None: + # overlay predicted mask on top of original image and save + + if pred_mask_pil.size != orig_image.size: + # resize if input image size is not the same as predicted mask. + # this is likely in case of labeled datasets where we use transforms on the input image + orig_image = F_vision.resize( + orig_image, + size=pred_mask_pil.size[::-1], + interpolation=F_vision.InterpolationMode.BILINEAR, + ) + + overlayed_img = Image.blend(pred_mask_pil, orig_image, alpha=alpha) + overlay_mask_dir = "{}/predictions_overlay".format(results_location) + if not os.path.isdir(overlay_mask_dir): + os.makedirs(overlay_mask_dir, exist_ok=True) + overlay_mask_f_name = "{}/{}".format(overlay_mask_dir, file_name) + overlayed_img.save(overlay_mask_f_name) + logger.log("RGB image blended with mask is saved at: {}".format(overlay_mask_f_name)) + + # save original image + rgb_image_dir = "{}/rgb_images".format(results_location) + if not os.path.isdir(rgb_image_dir): + os.makedirs(rgb_image_dir, exist_ok=True) + rgb_image_f_name = "{}/{}".format(rgb_image_dir, file_name) + orig_image.save(rgb_image_f_name) + logger.log("Original RGB image is saved at: {}".format(overlay_mask_f_name)) + + +def predict_labeled_dataset(opts, **kwargs) -> None: + device = getattr(opts, "dev.device", torch.device("cpu")) + mixed_precision_training = getattr(opts, "common.mixed_precision", False) + dataset_name = getattr(opts, "dataset.name", "") + + # set-up data loaders + val_loader = create_eval_loader(opts) + + # set-up the model + model = get_model(opts) + model.eval() + model = model.to(device=device) + print_summary(opts=opts, model=model) + + if model.training: + logger.log("Model is in training mode. Switching to evaluation mode") + model.eval() + + color_map = Colormap().get_color_map_list() + adjust_label = 0 + is_cityscape = False + conf_mat = ConfusionMatrix() + if hasattr(val_loader.dataset, "color_palette"): + color_map = val_loader.dataset.color_palette() + + if hasattr(val_loader.dataset, "adjust_mask_value"): + adjust_label = val_loader.dataset.adjust_mask_value() + + if dataset_name is not None and dataset_name.lower() == "cityscapes": + is_cityscape = True + + with torch.no_grad(): + for batch_id, batch in enumerate(val_loader): + samples, targets = batch["samples"], batch["targets"] + batch_size = samples.shape[0] + assert ( + batch_size == 1 + ), "We recommend to run segmentation evaluation with a batch size of 1" + + predict_and_save( + opts=opts, + input_tensor=samples, + file_name=targets["file_name"][0], + orig_w=targets["im_width"][0].item(), + orig_h=targets["im_height"][0].item(), + model=model, + target_mask=targets["mask"], + device=device, + mixed_precision_training=mixed_precision_training, + conf_mat=conf_mat, + color_map=color_map, + adjust_label=adjust_label, + is_cityscape=is_cityscape, + ) + + acc_global, acc, iu = conf_mat.compute() + logger.info("Quantitative results") + print( + "global correct: {:.2f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.2f}".format( + acc_global.item() * 100, + ["{:.2f}".format(i) for i in (acc * 100).tolist()], + ["{:.2f}".format(i) for i in (iu * 100).tolist()], + iu.mean().item() * 100, + ) + ) + + is_city_dataset = getattr(opts, "dataset.name", "") == "cityscapes" + if is_city_dataset: + from .segmentation_utils.cityscapes_iou import eval_cityscapes + + pred_dir = "{}/predictions_no_cmap/".format( + getattr(opts, "common.exp_loc", None) + ) + gt_dir = os.path.join(getattr(opts, "dataset.root_val", None), "gtFine/val/") + eval_cityscapes(pred_dir=pred_dir, gt_dir=gt_dir) + + +def read_and_process_image(opts, image_fname: str, *args, **kwargs): + input_img = Image.open(image_fname).convert("RGB") + input_pil = copy.deepcopy(input_img) + orig_w, orig_h = input_img.size + + # Resize the image while maitaining the aspect ratio + res_h, res_w = image_size_from_opts(opts) + + input_img = F_vision.resize( + input_img, + size=min(res_h, res_w), + interpolation=F_vision.InterpolationMode.BILINEAR, + ) + input_tensor = F_vision.pil_to_tensor(input_img) + input_tensor = input_tensor.float().div(255.0).unsqueeze(0) + return input_tensor, input_pil, orig_h, orig_w + + +def predict_image(opts, image_fname: str, **kwargs) -> None: + image_fname = get_local_path(opts, image_fname) + + if not os.path.isfile(image_fname): + logger.error("Image file does not exist at: {}".format(image_fname)) + + input_tensor, input_pil, orig_h, orig_w = read_and_process_image( + opts, image_fname=image_fname + ) + + image_fname = image_fname.split(os.sep)[-1] + + device = getattr(opts, "dev.device", torch.device("cpu")) + # set-up the model + model = get_model(opts) + model.eval() + model = model.to(device=device) + print_summary(opts=opts, model=model) + + if model.training: + logger.log("Model is in training mode. Switching to evaluation mode") + model.eval() + + with torch.no_grad(): + predict_and_save( + opts=opts, + input_tensor=input_tensor, + file_name=image_fname, + orig_h=orig_h, + orig_w=orig_w, + model=model, + target_mask=None, + device=device, + orig_image=input_pil, + ) + + +def predict_images_in_folder(opts, **kwargs) -> None: + img_folder_path = getattr(opts, "evaluation.segmentation.path", None) + if img_folder_path is None: + logger.error( + "Location of the folder containing images is not passed. Please use --evaluation.segmentation.path " + "as an argument to pass the location of the folder".format(img_folder_path) + ) + elif not os.path.isdir(img_folder_path): + logger.error( + "Folder containing images does not exist at: {}. Please check".format( + img_folder_path + ) + ) + + img_files = [] + for e in SUPPORTED_IMAGE_EXTNS: + img_files_with_extn = glob.glob("{}/*{}".format(img_folder_path, e)) + if len(img_files_with_extn) > 0 and isinstance(img_files_with_extn, list): + img_files.extend(img_files_with_extn) + + if len(img_files) == 0: + logger.error( + "Number of image files found at {}: {}".format( + img_folder_path, len(img_files) + ) + ) + + logger.log( + "Number of image files found at {}: {}".format(img_folder_path, len(img_files)) + ) + + device = getattr(opts, "dev.device", torch.device("cpu")) + mixed_precision_training = getattr(opts, "common.mixed_precision", False) + # set-up the model + model = get_model(opts) + model.eval() + model = model.to(device=device) + print_summary(opts=opts, model=model) + + if model.training: + logger.log("Model is in training mode. Switching to evaluation mode") + model.eval() + + with torch.no_grad(): + for image_fname in tqdm(img_files): + input_tensor, input_pil, orig_h, orig_w = read_and_process_image( + opts, image_fname=image_fname + ) + + image_fname = image_fname.split(os.sep)[-1] + + predict_and_save( + opts=opts, + input_tensor=input_tensor, + file_name=image_fname, + orig_h=orig_h, + orig_w=orig_w, + model=model, + target_mask=None, + device=device, + mixed_precision_training=mixed_precision_training, + orig_image=input_pil, + ) + + +def main_segmentation_evaluation(**kwargs) -> None: + opts = get_segmentation_eval_arguments() + + # device set-up + opts = device_setup(opts) + + node_rank = getattr(opts, "ddp.rank", 0) + if node_rank < 0: + logger.error("--rank should be >=0. Got {}".format(node_rank)) + + is_master_node = is_master(opts) + + # create the directory for saving results + save_dir = getattr(opts, "common.results_loc", "results") + run_label = getattr(opts, "common.run_label", "run_1") + exp_dir = "{}/{}".format(save_dir, run_label) + setattr(opts, "common.exp_loc", exp_dir) + logger.log("Results (if any) will be stored here: {}".format(exp_dir)) + + create_directories(dir_path=exp_dir, is_master_node=is_master_node) + + num_gpus = getattr(opts, "dev.num_gpus", 1) + # we disable the DDP setting for evaluating segmentation tasks + setattr(opts, "ddp.use_distributed", False) + + # No of data workers = no of CPUs (if not specified or -1) + n_cpus = multiprocessing.cpu_count() + dataset_workers = getattr(opts, "dataset.workers", -1) + + if dataset_workers == -1: + setattr(opts, "dataset.workers", n_cpus) + + # We are not performing any operation like resizing and cropping on images + # Because image dimensions are different, we process 1 sample at a time. + setattr(opts, "dataset.train_batch_size0", 1) + setattr(opts, "dataset.val_batch_size0", 1) + setattr(opts, "dev.device_id", None) + + eval_mode = getattr(opts, "evaluation.segmentation.mode", None) + + if eval_mode == "single_image": + num_seg_classes = getattr(opts, "model.segmentation.n_classes", 21) + assert num_seg_classes is not None + + # test a single image + img_f_name = getattr(opts, "evaluation.segmentation.path", None) + predict_image(opts, img_f_name, **kwargs) + elif eval_mode == "image_folder": + num_seg_classes = getattr(opts, "model.segmentation.n_classes", 21) + assert num_seg_classes is not None + + # test all images in a folder + # This is useful for generating results for test set + predict_images_in_folder(opts=opts, **kwargs) + elif eval_mode == "validation_set": + # evaluate and compute stats for labeled image dataset + # This is useful for generating results for validation set and compute quantitative results + predict_labeled_dataset(opts=opts, **kwargs) + else: + logger.error( + "Supported modes are single_image, image_folder, and validation_set. Got: {}".format( + eval_mode + ) + ) diff --git a/Adaptive Frequency Filters/engine/evaluation_engine.py b/Adaptive Frequency Filters/engine/evaluation_engine.py new file mode 100644 index 0000000..34e2e7e --- /dev/null +++ b/Adaptive Frequency Filters/engine/evaluation_engine.py @@ -0,0 +1,204 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +import time + +from metrics import Statistics, metric_monitor +from options.parse_args import parse_validation_metric_names +from utils.ddp_utils import is_master +from utils import logger +from utils.common_utils import move_to_device +from engine.utils import print_summary +from common import DEFAULT_LOG_FREQ, SUPPORTED_VIDEO_CLIP_VOTING_FN + +from .utils import get_batch_size, autocast_fn + + +class Evaluator(object): + def __init__(self, opts, model, eval_loader): + super(Evaluator, self).__init__() + + self.opts = opts + + self.model = model + + self.eval_loader = eval_loader + + self.device = getattr(opts, "dev.device", torch.device("cpu")) + self.use_distributed = getattr(self.opts, "ddp.use_distributed", False) + self.is_master_node = is_master(opts) + self.stage_name = getattr(opts, "common.eval_stage_name", "evaluation") + + self.mixed_precision_training = getattr(opts, "common.mixed_precision", False) + self.mixed_precision_dtype = getattr( + opts, "common.mixed_precision_dtype", "float16" + ) + + ( + self.metric_names, + self.ckpt_metric, + self.ckpt_submetric, + ) = parse_validation_metric_names(self.opts) + + if self.is_master_node: + print_summary(opts=self.opts, model=self.model) + + # inference modality based eval function + self.eval_fn = self.eval_fn_image + inference_modality = getattr(opts, "common.inference_modality", "image") + if inference_modality is not None and inference_modality.lower() == "video": + self.eval_fn = self.eval_fn_video + + def eval_fn_image(self, model): + log_freq = getattr(self.opts, "common.log_freq", DEFAULT_LOG_FREQ) + + evaluation_stats = Statistics( + metric_names=self.metric_names, is_master_node=self.is_master_node + ) + + model.eval() + if model.training and self.is_master_node: + logger.warning("Model is in training mode. Switching to evaluation mode") + model.eval() + + with torch.no_grad(): + epoch_start_time = time.time() + total_samples = len(self.eval_loader) + processed_samples = 0 + + for batch_id, batch in enumerate(self.eval_loader): + batch = move_to_device(opts=self.opts, x=batch, device=self.device) + + samples, targets = batch["samples"], batch["targets"] + + batch_size = get_batch_size(samples) + + with autocast_fn( + enabled=self.mixed_precision_training, + amp_precision=self.mixed_precision_dtype, + ): + # prediction + pred_label = model(samples) + + processed_samples += batch_size + metrics = metric_monitor( + self.opts, + pred_label=pred_label, + target_label=targets, + loss=torch.tensor(0.0, dtype=torch.float, device=self.device), + use_distributed=self.use_distributed, + metric_names=self.metric_names, + ) + + evaluation_stats.update( + metric_vals=metrics, batch_time=0.0, n=batch_size + ) + + if batch_id % log_freq == 0 and self.is_master_node: + evaluation_stats.iter_summary( + epoch=-1, + n_processed_samples=processed_samples, + total_samples=total_samples, + elapsed_time=epoch_start_time, + learning_rate=0.0, + ) + + evaluation_stats.epoch_summary(epoch=-1, stage=self.stage_name) + + def eval_fn_video(self, model): + log_freq = getattr(self.opts, "common.log_freq", DEFAULT_LOG_FREQ) + + evaluation_stats = Statistics( + metric_names=self.metric_names, is_master_node=self.is_master_node + ) + + model.eval() + if model.training and self.is_master_node: + logger.warning("Model is in training mode. Switching to evaluation mode") + model.eval() + + num_clips_per_video = getattr(self.opts, "sampler.bs.clips_per_video", 1) + voting_fn = getattr( + self.opts, "model.video_classification.clip_out_voting_fn", "sum" + ) + if voting_fn is None: + voting_fn = "sum" + voting_fn = voting_fn.lower() + + with torch.no_grad(): + epoch_start_time = time.time() + total_samples = len(self.eval_loader) + processed_samples = 0 + + for batch_id, batch in enumerate(self.eval_loader): + batch = move_to_device(opts=self.opts, x=batch, device=self.device) + + samples, targets = batch["samples"], batch["targets"] + # target_label is Batch*Num_clips + batch_size_ = get_batch_size(samples) + batch_size = batch_size_ // num_clips_per_video + if batch_size_ != (batch_size * num_clips_per_video): + logger.log( + "Skipping batch. Expected batch size= {}. Got: (bxc:{}x{})".format( + batch_size_, batch_size, num_clips_per_video + ) + ) + continue + + with autocast_fn( + enabled=self.mixed_precision_training, + amp_precision=self.mixed_precision_dtype, + ): + # prediction + pred_label = model(samples) + + targets = targets.reshape(batch_size, num_clips_per_video) + # label is the same for all clips in the video + targets = targets[:, 0] + pred_label = pred_label.reshape(batch_size, num_clips_per_video, -1) + + if voting_fn == "sum": + pred_label = torch.sum(pred_label, dim=1) + elif voting_fn == "max": + pred_label = torch.max(pred_label, dim=1) + else: + logger.error( + "--model.video-classification.clip-out-fusion-fn can be {}. Got: {}".format( + SUPPORTED_VIDEO_CLIP_VOTING_FN, voting_fn + ) + ) + + processed_samples += batch_size + metrics = metric_monitor( + self.opts, + pred_label=pred_label, + target_label=targets, + loss=torch.tensor(0.0, dtype=torch.float, device=self.device), + use_distributed=self.use_distributed, + metric_names=self.metric_names, + ) + + evaluation_stats.update( + metric_vals=metrics, batch_time=0.0, n=batch_size + ) + + if batch_id % log_freq == 0 and self.is_master_node: + evaluation_stats.iter_summary( + epoch=-1, + n_processed_samples=processed_samples, + total_samples=total_samples, + elapsed_time=epoch_start_time, + learning_rate=0.0, + ) + + evaluation_stats.epoch_summary(epoch=-1, stage=self.stage_name) + + def run(self): + eval_start_time = time.time() + self.eval_fn(model=self.model) + eval_end_time = time.time() - eval_start_time + logger.log("Evaluation took {} seconds".format(eval_end_time)) diff --git a/Adaptive Frequency Filters/engine/segmentation_utils/__init__.py b/Adaptive Frequency Filters/engine/segmentation_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/engine/segmentation_utils/cityscapes_iou.py b/Adaptive Frequency Filters/engine/segmentation_utils/cityscapes_iou.py new file mode 100644 index 0000000..8a7f623 --- /dev/null +++ b/Adaptive Frequency Filters/engine/segmentation_utils/cityscapes_iou.py @@ -0,0 +1,42 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as cityscapes_semseg_eval +import os +import glob + +from utils import logger + + +def eval_cityscapes(pred_dir: str, gt_dir: str) -> None: + """Utility to evaluate on cityscapes dataset""" + cityscapes_semseg_eval.args.predictionPath = pred_dir + cityscapes_semseg_eval.args.predictionWalk = None + cityscapes_semseg_eval.args.JSONOutput = False + cityscapes_semseg_eval.args.colorized = False + + gt_img_list = glob.glob(os.path.join(gt_dir, "*", "*_gtFine_labelIds.png")) + if len(gt_img_list) == 0: + logger.error("Cannot find ground truth images at: {}".format(gt_dir)) + + pred_img_list = [] + for gt in gt_img_list: + pred_img_list.append( + cityscapes_semseg_eval.getPrediction(cityscapes_semseg_eval.args, gt) + ) + + results = cityscapes_semseg_eval.evaluateImgLists( + pred_img_list, gt_img_list, cityscapes_semseg_eval.args + ) + + logger.info("Evaluation results summary") + eval_res_str = "\n\t IoU_cls: {:.2f} \n\t iIOU_cls: {:.2f} \n\t IoU_cat: {:.2f} \n\t iIOU_cat: {:.2f}".format( + 100.0 * results["averageScoreClasses"], + 100.0 * results["averageScoreInstClasses"], + 100.0 * results["averageScoreCategories"], + 100.0 * results["averageScoreInstCategories"], + ) + print(eval_res_str) diff --git a/Adaptive Frequency Filters/engine/training_engine.py b/Adaptive Frequency Filters/engine/training_engine.py new file mode 100644 index 0000000..4e417fd --- /dev/null +++ b/Adaptive Frequency Filters/engine/training_engine.py @@ -0,0 +1,1004 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import sys +import traceback +import torch +from torch import Tensor +import copy +import gc +import time +import shutil +from typing import Dict +import wandb +from torch.cuda.amp import autocast +from torch.nn import functional as F +import random +from typing import Union, List, Optional +import numpy as np +from itertools import product + +from data.transforms.image_torch import RandomMixup, RandomCutmix +from engine.utils import print_summary, get_log_writers +from metrics import Statistics, metric_monitor +from common import DEFAULT_ITERATIONS, DEFAULT_EPOCHS, DEFAULT_LOG_FREQ +from options.parse_args import parse_validation_metric_names +from utils import logger +from utils.common_utils import create_directories, move_to_device +from utils.ddp_utils import is_master, dist_barrier +from utils.tensor_utils import reduce_tensor_sum, tensor_to_python_float +from utils.checkpoint_utils import ( + copy_weights, + save_checkpoint, + save_interval_checkpoint, +) +from loss_landscape import landscape_utils as ll_utils + + +from .utils import get_batch_size, autocast_fn, log_metrics + + +class Trainer(object): + """ + This class defines the training and validation code for training models with affnet + """ + + def __init__( + self, + opts, + model, + validation_loader, + training_loader, + criterion, + optimizer, + scheduler, + gradient_scalar, + start_epoch: int = 0, + start_iteration: int = 0, + best_metric: float = 0.0, + model_ema=None, + *args, + **kwargs + ) -> None: + super(Trainer, self).__init__() + + self.opts = opts + + self.model = model + self.model_ema = model_ema + self.criteria = criterion + self.optimizer = optimizer + self.scheduler = scheduler + self.gradient_scalar = gradient_scalar + + self.val_loader = validation_loader + self.train_loader = training_loader + + self.device = getattr(opts, "dev.device", torch.device("cpu")) + + self.start_epoch = start_epoch + self.best_metric = best_metric + self.train_iterations = start_iteration + + self.is_master_node = is_master(opts) + self.max_iterations_reached = False + self.max_iterations = getattr( + self.opts, "scheduler.max_iterations", DEFAULT_ITERATIONS + ) + self.use_distributed = getattr(self.opts, "ddp.use_distributed", False) + self.log_freq = getattr(self.opts, "common.log_freq", DEFAULT_LOG_FREQ) + self.accum_freq = getattr(self.opts, "common.accum_freq", 1) + self.accum_after_epoch = getattr(self.opts, "common.accum_after_epoch", 0) + + self.mixed_precision_training = getattr(opts, "common.mixed_precision", False) + self.mixed_precision_dtype = getattr( + opts, "common.mixed_precision_dtype", "float16" + ) + + self.train_metric_names = getattr(opts, "stats.train", ["loss"]) + if isinstance(self.train_metric_names, str): + self.train_metric_names = [self.train_metric_names] + + assert isinstance( + self.train_metric_names, list + ), "Type of metric names should be list. Got: {}".format( + type(self.train_metric_names) + ) + + if "loss" not in self.train_metric_names: + self.train_metric_names.append(self.train_metric_names) + + ( + self.val_metric_names, + self.ckpt_metric, + self.ckpt_submetric, + ) = parse_validation_metric_names(self.opts) + + self.save_all_checkpoints = getattr( + self.opts, "common.save_all_checkpoints", False + ) + + self.save_location = getattr(opts, "common.exp_loc", "results/run_1") + + self.log_writers = get_log_writers(self.opts, save_location=self.save_location) + if self.is_master_node: + print_summary( + opts=self.opts, + model=self.model, + criteria=self.criteria, + optimizer=self.optimizer, + scheduler=self.scheduler, + ) + + self.adjust_norm_mom = None + if getattr(opts, "model.normalization.adjust_bn_momentum.enable", False): + from affnet.layers import AdjustBatchNormMomentum + + self.adjust_norm_mom = AdjustBatchNormMomentum(opts=opts) + if self.is_master_node: + logger.log( + "Batch normalization momentum will be annealed during training." + ) + print(self.adjust_norm_mom) + + # sample-efficient training + self.cache_dict = None + self.sample_efficient_training = getattr( + opts, "dataset.sample_efficient_training.enable", False + ) + self.sample_confidence = getattr( + opts, "dataset.sample_efficient_training.sample_confidence", 0.5 + ) + self.find_easy_samples_every_k_epoch = getattr( + opts, + "dataset.sample_efficient_training.find_easy_samples_every_k_epochs", + 5, + ) + self.min_sample_frequency = getattr( + opts, "dataset.sample_efficient_training.min_sample_frequency", 5 + ) + if self.sample_efficient_training: + self.train_loader_set = copy.deepcopy(self.train_loader) + self.sample_ids_orig = self.train_loader_set.get_sample_indices() + n_samples = len(self.sample_ids_orig) + self.running_sum_tensor = torch.zeros( + (n_samples,), device=self.device, dtype=torch.int + ) + self.running_sum_tensor.requires_grad = False + if self.is_master_node: + logger.log("Configuring for sample efficient training") + + # recent versions of PyTorch support setting grads to None, for better performance + # To be explored in Future + # self.optimizer.zero_grad(set_to_none=True) + self.set_grad_to_none = False + + save_interval_freq = getattr(opts, "common.save_interval_freq", 0) + # save interval checkpoints every `save_interval_freq` updates on the master node + self.save_interval = self.is_master_node and save_interval_freq > 0 + self.save_interval_freq = save_interval_freq + + def compute_grad_norm(self): + parameters = [p for p in self.model.parameters() if p.grad is not None] + if len(parameters) == 0: + return None + + norm_type = 2.0 # L2 norm + + inv_scale = 1.0 / self.gradient_scalar.get_scale() + total_norm = torch.norm( + torch.stack( + [ + torch.norm(p.grad.detach() * inv_scale, norm_type).to(self.device) + for p in parameters + ] + ), + norm_type, + ) + if total_norm.isnan() or total_norm.isinf(): + return None + return total_norm + + def apply_mixup_transforms(self, data): + # Apply mixup transforms on classification tasks + opts = self.opts + mixup_transforms = [] + if getattr(opts, "image_augmentation.mixup.enable", False): + n_classes = getattr(opts, "model.classification.n_classes", None) + if n_classes is None: + logger.error("Please specify number of classes. Got None.") + mixup_transforms.append(RandomMixup(opts=opts, num_classes=n_classes)) + + if getattr(opts, "image_augmentation.cutmix.enable", False): + n_classes = getattr(opts, "model.classification.n_classes", None) + if n_classes is None: + logger.error("Please specify number of classes. Got None.") + mixup_transforms.append(RandomCutmix(opts=opts, num_classes=n_classes)) + + if len(mixup_transforms) > 0: + _mixup_transform = random.choice(mixup_transforms) + data = _mixup_transform(data) + return data + + def _zero_grad(self): + if self.set_grad_to_none: + self.optimizer.zero_grad(set_to_none=True) + else: + self.optimizer.zero_grad() + + def train_epoch(self, epoch): + time.sleep(2) # To prevent possible deadlock during epoch transition + + if self.is_master_node: + logger.double_dash_line() + logger.debug( + "Training epoch {} with {} samples".format( + epoch, self.train_loader.samples_in_dataset() + ) + ) + + train_stats = Statistics( + metric_names=self.train_metric_names, is_master_node=self.is_master_node + ) + + self.model.train() + accum_freq = self.accum_freq if epoch >= self.accum_after_epoch else 1 + max_norm = getattr(self.opts, "common.grad_clip", None) + + # set the gradient to zero or None + self._zero_grad() + + epoch_start_time = time.time() + batch_load_start = time.time() + grad_norm = torch.tensor([0.0], dtype=torch.float, device=self.device) + for batch_id, batch in enumerate(self.train_loader): + if self.train_iterations > self.max_iterations: + self.max_iterations_reached = True + return -1, -1 + + # move to device + batch = move_to_device(opts=self.opts, x=batch, device=self.device) + # apply mix-up transforms if any + batch = self.apply_mixup_transforms(data=batch) + + batch_load_toc = time.time() - batch_load_start + + samples, targets = batch["samples"], batch["targets"] + + batch_size = get_batch_size(samples) + + # update the learning rate + self.optimizer = self.scheduler.update_lr( + optimizer=self.optimizer, epoch=epoch, curr_iter=self.train_iterations + ) + + # adjust bn momentum + if self.adjust_norm_mom is not None: + self.adjust_norm_mom.adjust_momentum( + model=self.model, epoch=epoch, iteration=self.train_iterations + ) + + with autocast_fn( + enabled=self.mixed_precision_training, + amp_precision=self.mixed_precision_dtype, + ): + # prediction + pred_label = self.model(samples) + # compute loss + loss_dict_or_tensor: Union[Dict, Tensor] = self.criteria( + input_sample=samples, + prediction=pred_label, + target=targets, + epoch=epoch, + iterations=self.train_iterations, + ) + + if isinstance(loss_dict_or_tensor, Dict): + if "total_loss" not in loss_dict_or_tensor.keys(): + logger.error( + "total_loss key is required for loss functions that return outputs as dictionary." + ) + loss = loss_dict_or_tensor["total_loss"] + elif isinstance(loss_dict_or_tensor, Tensor): + loss = loss_dict_or_tensor + else: + logger.error("Loss value should be an instance of Tensor or Dict") + + if isinstance(loss, torch.Tensor) and torch.isnan(loss): + logger.error("Nan encountered in the loss.") + + # perform the backward pass with gradient accumulation [Optional] + self.gradient_scalar.scale(loss).backward() + + if (batch_id + 1) % accum_freq == 0: + if max_norm is not None: + # For gradient clipping, unscale the gradients and then clip them + self.gradient_scalar.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), max_norm=max_norm + ) + + if "grad_norm" in self.train_metric_names: + # compute grad_norm for logging purposes. + # We can't use the output of clip_grad_norm_ because it returns the total norm before clipping + grad_norm = self.compute_grad_norm() + + # optimizer step + self.gradient_scalar.step(optimizer=self.optimizer) + # update the scale for next batch + self.gradient_scalar.update() + # set the gradient to zero or None + self._zero_grad() + + self.train_iterations += 1 + + if self.model_ema is not None: + self.model_ema.update_parameters(self.model) + + metrics = metric_monitor( + self.opts, + pred_label=pred_label, + target_label=targets, + loss=loss_dict_or_tensor, + grad_norm=grad_norm, + use_distributed=self.use_distributed, + metric_names=self.train_metric_names, + ) + + train_stats.update( + metric_vals=metrics, batch_time=batch_load_toc, n=batch_size + ) + + # save the checkpoint every N updates + if ( + self.save_interval + and (self.train_iterations % self.save_interval_freq) == 0 + ): + save_interval_checkpoint( + iterations=self.train_iterations, + epoch=epoch, + model=self.model, + optimizer=self.optimizer, + best_metric=loss.item(), + save_dir=self.save_location, + gradient_scalar=self.gradient_scalar, + not_intermediate_checkpoint=False, + ) + logger.info( + "Checkpoints saved after {} updates at: {}".format( + self.train_iterations, self.save_location + ), + print_line=True, + ) + + if batch_id % self.log_freq == 0 and self.is_master_node: + lr = self.scheduler.retrieve_lr(self.optimizer) + train_stats.iter_summary( + epoch=epoch, + n_processed_samples=self.train_iterations, + total_samples=self.max_iterations, + learning_rate=lr, + elapsed_time=epoch_start_time, + ) + + batch_load_start = time.time() + + avg_loss = train_stats.avg_statistics( + metric_name="loss", sub_metric_name="total_loss" + ) + train_stats.epoch_summary(epoch=epoch, stage="training") + avg_ckpt_metric = train_stats.avg_statistics( + metric_name=self.ckpt_metric, sub_metric_name=self.ckpt_submetric + ) + + gc.collect() + + return avg_loss, avg_ckpt_metric + + def val_epoch(self, epoch, model, extra_str=""): + if self.val_loader is None: + return 0.0, 0.0 + + time.sleep(2) # To prevent possible deadlock during epoch transition + validation_stats = Statistics( + metric_names=self.val_metric_names, is_master_node=self.is_master_node + ) + + if "coco_map" in self.val_metric_names: + from metrics.coco_map import COCOEvaluator + + coco_evaluator = COCOEvaluator( + opts=self.opts, use_distributed=self.use_distributed + ) + else: + coco_evaluator = None + + model.eval() + if model.training and self.is_master_node: + logger.warning("Model is in training mode. Switching to evaluation mode") + model.eval() + + with torch.no_grad(): + epoch_start_time = time.time() + total_samples = len(self.val_loader) + processed_samples = 0 + lr = self.scheduler.retrieve_lr(self.optimizer) + for batch_id, batch in enumerate(self.val_loader): + batch = move_to_device(opts=self.opts, x=batch, device=self.device) + + samples, targets = batch["samples"], batch["targets"] + + batch_size = get_batch_size(samples) + + with autocast_fn( + enabled=self.mixed_precision_training, + amp_precision=self.mixed_precision_dtype, + ): + # prediction + pred_label = model(samples) + # compute loss + loss_dict_or_tensor = self.criteria( + input_sample=samples, + prediction=pred_label, + target=targets, + ) + + processed_samples += batch_size + + metrics = metric_monitor( + self.opts, + pred_label=pred_label, + target_label=targets, + loss=loss_dict_or_tensor, + use_distributed=self.use_distributed, + metric_names=self.val_metric_names, + is_evaluation=True, + ) + + validation_stats.update( + metric_vals=metrics, batch_time=0.0, n=batch_size + ) + + if coco_evaluator is not None: + coco_evaluator.prepare_predictions( + predictions=pred_label, targets=targets + ) + + if batch_id % self.log_freq == 0 and self.is_master_node: + validation_stats.iter_summary( + epoch=epoch, + n_processed_samples=processed_samples, + total_samples=total_samples, + elapsed_time=epoch_start_time, + learning_rate=lr, + ) + + validation_stats.epoch_summary(epoch=epoch, stage="validation" + extra_str) + avg_loss = validation_stats.avg_statistics( + metric_name="loss", sub_metric_name="total_loss" + ) + avg_ckpt_metric = validation_stats.avg_statistics( + metric_name=self.ckpt_metric, sub_metric_name=self.ckpt_submetric + ) + + if coco_evaluator is not None: + # synchronize across different processes and aggregate the results + coco_evaluator.gather_coco_results() + coco_map = coco_evaluator.summarize_coco_results() + + if self.ckpt_metric == "coco_map" and "bbox" in coco_map: + avg_ckpt_metric = round(coco_map["bbox"], 5) + + if avg_ckpt_metric is None: + avg_ckpt_metric = avg_loss + + gc.collect() + + return avg_loss, avg_ckpt_metric + + def find_easy_samples(self, epoch, model, *args, **kwargs): + """ + This function identifies easy samples in the training set and removes them from training. + + .. note:: + Currently, this is implemented separately to avoid breaking the training and validation pipeline. In future, + this will be combined with main training loop to reduce overhead. + """ + + time.sleep(2) # To prevent possible deadlock during epoch transition + + model.eval() + if model.training and self.is_master_node: + logger.warning("Model is in training mode. Switching to evaluation mode") + model.eval() + + if self.is_master_node: + logger.log("Trying to find easy samples in epoch {}".format(epoch)) + + with torch.no_grad(): + easy_sample_ids_tensor = torch.zeros_like(self.running_sum_tensor) + + for batch_id, batch in enumerate(self.train_loader_set): + batch = move_to_device(opts=self.opts, x=batch, device=self.device) + + samples, targets = batch["samples"], batch["targets"] + + sample_ids = None + if "sample_id" in batch: + sample_ids = batch["sample_id"] + else: + self.sample_efficient_training = False + if self.is_master_node: + logger.log( + "Sample Ids are required in a batch for sample efficient training. " + "sample_id key not found in batch. Disabling sample efficient training." + ) + break + + if sample_ids is None: + logger.log("Sample Ids can't be none") + break + + with autocast_fn( + enabled=self.mixed_precision_training, + amp_precision=self.mixed_precision_dtype, + ): + # prediction + pred_label = model(samples) + pred_label = F.softmax(pred_label, dim=-1) + + pred_conf, pred_indices = torch.max(pred_label, dim=-1) + + easy_samples = torch.logical_and( + pred_indices.eq( + targets + ), # condition 1: Predicted label == Target label + pred_conf + >= self.sample_confidence, # condition 2: prediction confidence >= desired confidence + ) + + if easy_samples.numel() > 0: + easy_sample_ids = sample_ids[easy_samples] + # find easy samples as per condition 1 and 2 and set their values to 1 + easy_sample_ids_tensor[easy_sample_ids] = 1 + + # synchronize tensors + if self.use_distributed: + # sync across all GPUs. + easy_sample_ids_tensor = reduce_tensor_sum(easy_sample_ids_tensor) + + # some samples which are classified easy earlier may have been classified hard now. + easy_sample_ids_tensor[easy_sample_ids_tensor == 0] = -1 + + if self.is_master_node: + logger.debug( + "Number of easy samples found during epoch {} are {}".format( + epoch, + easy_sample_ids_tensor[easy_sample_ids_tensor > 0].sum().item(), + ) + ) + + self.running_sum_tensor = torch.clip( + self.running_sum_tensor + easy_sample_ids_tensor, + min=0, + max=self.min_sample_frequency, + ) + + if self.running_sum_tensor.sum() > 0: + skip_sample_ids = ( + self.running_sum_tensor >= self.min_sample_frequency + ).nonzero(as_tuple=True)[0] + + if skip_sample_ids.numel() > 0: + skip_samples = skip_sample_ids.cpu().numpy().tolist() + + new_sample_ids = [ + s_id + for s_id in self.sample_ids_orig + if s_id not in skip_sample_ids + ] + + # update the train loader indices + self.train_loader.update_indices(new_sample_ids) + + if self.is_master_node: + logger.debug( + "Number of samples to skip after epoch {} are {}".format( + epoch, len(skip_samples) + ) + ) + + def run(self, train_sampler=None): + if train_sampler is None and self.is_master_node: + logger.error("Train sampler cannot be None") + + copy_at_epoch = getattr(self.opts, "ema.copy_at_epoch", -1) + train_start_time = time.time() + + cfg_file = getattr(self.opts, "common.config_file", None) + if cfg_file is not None and self.is_master_node: + dst_cfg_file = "{}/config.yaml".format(self.save_location) + shutil.copy(src=cfg_file, dst=dst_cfg_file) + logger.info( + "Configuration file is stored here: {}".format( + logger.color_text(dst_cfg_file) + ) + ) + + keep_k_best_ckpts = getattr(self.opts, "common.k_best_checkpoints", 5) + ema_best_metric = self.best_metric + is_ema_best = False + + try: + max_epochs = getattr(self.opts, "scheduler.max_epochs", DEFAULT_EPOCHS) + max_checkpoint_metric = getattr( + self.opts, "stats.checkpoint_metric_max", False + ) + # TODO: to delete + # val_loss, val_ckpt_metric = self.val_epoch( + # epoch=-1, model=self.model + # ) + for epoch in range(self.start_epoch, max_epochs): + # Note that we are using our owm implementations of data samplers + # and we have defined this function for both distributed and non-distributed cases + train_sampler.set_epoch(epoch) + train_sampler.update_scales( + epoch=epoch, is_master_node=self.is_master_node + ) + + train_loss, train_ckpt_metric = self.train_epoch(epoch) + if self.opts.log_wandb and self.is_master_node: + wandb.log({'train_loss': train_loss}) + wandb.log({'train_top1': train_ckpt_metric}) + + val_loss, val_ckpt_metric = self.val_epoch( + epoch=epoch, model=self.model + ) + if self.opts.log_wandb and self.is_master_node: + wandb.log({'val_loss': val_loss}) + wandb.log({'val_top1': val_ckpt_metric}) + + if epoch == copy_at_epoch and self.model_ema is not None: + if self.is_master_node: + logger.log("Copying EMA weights") + # copy model_src weights to model_tgt + self.model = copy_weights( + model_tgt=self.model, model_src=self.model_ema + ) + if self.is_master_node: + logger.log("EMA weights copied") + logger.log("Running validation after Copying EMA model weights") + self.val_epoch(epoch=epoch, model=self.model) + + if max_checkpoint_metric: + is_best = val_ckpt_metric >= self.best_metric + self.best_metric = max(val_ckpt_metric, self.best_metric) + else: + is_best = val_ckpt_metric <= self.best_metric + self.best_metric = min(val_ckpt_metric, self.best_metric) + + val_ema_loss = None + val_ema_ckpt_metric = None + if self.model_ema is not None: + val_ema_loss, val_ema_ckpt_metric = self.val_epoch( + epoch=epoch, model=self.model_ema.ema_model, extra_str=" (EMA)" + ) + if self.opts.log_wandb and self.is_master_node: + wandb.log({'val_ema_loss': val_ema_loss}) + wandb.log({'val_ema_top1': val_ema_ckpt_metric}) + if max_checkpoint_metric: + is_ema_best = val_ema_ckpt_metric >= ema_best_metric + ema_best_metric = max(val_ema_ckpt_metric, ema_best_metric) + else: + is_ema_best = val_ema_ckpt_metric <= ema_best_metric + ema_best_metric = min(val_ema_ckpt_metric, ema_best_metric) + + # sample efficient training + if ( + self.sample_efficient_training + and (epoch + 1) % self.find_easy_samples_every_k_epoch == 0 + ): + self.find_easy_samples( + epoch=epoch, + model=self.model + if self.model_ema is not None + else self.model_ema.ema_model, + ) + + gc.collect() + + if self.is_master_node: + save_checkpoint( + iterations=self.train_iterations, + epoch=epoch, + model=self.model, + optimizer=self.optimizer, + best_metric=self.best_metric, + is_best=is_best, + save_dir=self.save_location, + model_ema=self.model_ema, + is_ema_best=is_ema_best, + ema_best_metric=ema_best_metric, + gradient_scalar=self.gradient_scalar, + max_ckpt_metric=max_checkpoint_metric, + k_best_checkpoints=keep_k_best_ckpts, + save_all_checkpoints=self.save_all_checkpoints, + ) + logger.info( + "Checkpoints saved at: {}".format(self.save_location), + print_line=True, + ) + + if self.is_master_node: + lr_list = self.scheduler.retrieve_lr(self.optimizer) + + for log_writer in self.log_writers: + log_metrics( + lrs=lr_list, + log_writer=log_writer, + train_loss=train_loss, + val_loss=val_loss, + epoch=epoch, + best_metric=self.best_metric, + val_ema_loss=val_ema_loss, + ckpt_metric_name=self.ckpt_metric, + train_ckpt_metric=train_ckpt_metric, + val_ckpt_metric=val_ckpt_metric, + val_ema_ckpt_metric=val_ema_ckpt_metric, + ) + + if self.max_iterations_reached: + if self.use_distributed: + dist_barrier() + + if self.is_master_node: + logger.info("Max. iterations for training reached") + break + except KeyboardInterrupt as e: + if self.is_master_node: + logger.log("Keyboard interruption. Exiting from early training") + raise e + except Exception as e: + if "out of memory" in str(e): + logger.log("OOM exception occured") + n_gpus = getattr(self.opts, "dev.num_gpus", 1) + for dev_id in range(n_gpus): + mem_summary = torch.cuda.memory_summary( + device=torch.device("cuda:{}".format(dev_id)), abbreviated=True + ) + logger.log("Memory summary for device id: {}".format(dev_id)) + print(mem_summary) + else: + logger.log( + "Exception occurred that interrupted the training. {}".format( + str(e) + ) + ) + print(e) + traceback.print_exc() + raise e + finally: + use_distributed = getattr(self.opts, "ddp.use_distributed", False) + if use_distributed: + torch.distributed.destroy_process_group() + + torch.cuda.empty_cache() + + for log_writer in self.log_writers: + log_writer.close() + + if self.is_master_node: + train_end_time = time.time() + hours, rem = divmod(train_end_time - train_start_time, 3600) + minutes, seconds = divmod(rem, 60) + train_time_str = "{:0>2}:{:0>2}:{:05.2f}".format( + int(hours), int(minutes), seconds + ) + logger.log("Training took {}".format(train_time_str)) + + def run_loss_landscape(self): + # Loss landscape code is adapted from https://github.com/xxxnell/how-do-vits-work + ll_start_time = time.time() + try: + n_points = getattr(self.opts, "loss_landscape.n_points", 32) + min_x = getattr(self.opts, "loss_landscape.min_x", -1.0) + max_x = getattr(self.opts, "loss_landscape.max_x", 1.0) + min_y = getattr(self.opts, "loss_landscape.min_y", -1.0) + max_y = getattr(self.opts, "loss_landscape.max_y", 1.0) + + if self.is_master_node: + logger.log( + "Loss landscape coord space params: \n\tmin_x={}\n\tmax_x={}\n\tmin_y={}\n\tmax_y={}\n\tn_points={}".format( + min_x, max_x, min_y, max_y, n_points + ) + ) + + ll_metrics = ["loss"] + ll_stats = Statistics( + metric_names=ll_metrics, is_master_node=self.is_master_node + ) + has_module = hasattr(self.model, "module") + model_name = ( + self.model.module.__class__.__name__ + if has_module + else self.model.__class__.__name__ + ) + + # copy the model and create bases + model = copy.deepcopy(self.model) + weight_state_0 = ( + copy.deepcopy(model.module.state_dict()) + if has_module + else copy.deepcopy(model.state_dict()) + ) + bases = ll_utils.create_bases( + model=model, device=self.device, has_module=has_module + ) + + xs = np.linspace(min_x, max_x, n_points) + ys = np.linspace(min_y, max_y, n_points) + + grid_a, grid_b = np.meshgrid(xs, ys, indexing="xy") + loss_surface = np.empty_like(grid_a) + + epoch = -1 + for coord_a, coord_b in product(range(n_points), range(n_points)): + epoch += 1 + coords_list = [grid_a[coord_a, coord_b], grid_b[coord_a, coord_b]] + weight_state_1 = copy.deepcopy(weight_state_0) + gs = [{k: r * bs[k] for k in bs} for r, bs in zip(coords_list, bases)] + gs = { + k: torch.sum(torch.stack([g[k] for g in gs]), dim=0) + + weight_state_1[k] + for k in gs[0] + } + + # load the weights + model.module.load_state_dict( + gs + ) if has_module else model.load_state_dict(gs) + + model = model.to(device=self.device) + model.eval() + + total_samples = len(self.val_loader) + with torch.no_grad(): + epoch_start_time = time.time() + processed_samples = 0 + for batch_id, batch in enumerate(self.val_loader): + batch = move_to_device( + opts=self.opts, x=batch, device=self.device + ) + samples, targets = batch["samples"], batch["targets"] + + batch_size = get_batch_size(samples) + processed_samples += batch_size + + # make the prediction and compute loss + pred_label = model(samples) + loss_dict_or_tensor: Union[Dict, Tensor] = self.criteria( + input_sample=samples, + prediction=pred_label, + target=targets, + ) + + if isinstance(loss_dict_or_tensor, Dict): + if "total_loss" not in loss_dict_or_tensor.keys(): + logger.error( + "total_loss key is required for loss functions that return outputs as dictionary." + ) + loss = loss_dict_or_tensor["total_loss"] + elif isinstance(loss_dict_or_tensor, Tensor): + loss = loss_dict_or_tensor + else: + logger.error( + "Loss value should be an instance of Tensor or Dict" + ) + + if isinstance(loss, torch.Tensor) and torch.isnan(loss): + logger.error("Nan encountered in the loss.") + + metrics = metric_monitor( + self.opts, + pred_label=pred_label, + target_label=targets, + loss=loss_dict_or_tensor, + use_distributed=self.use_distributed, + metric_names=ll_metrics, + is_evaluation=True, + ) + + ll_stats.update( + metric_vals=metrics, batch_time=0.0, n=batch_size + ) + + if batch_id % self.log_freq == 0 and self.is_master_node: + ll_stats.iter_summary( + epoch=epoch, + n_processed_samples=processed_samples, + total_samples=total_samples, + elapsed_time=epoch_start_time, + learning_rate=0.0, + ) + + avg_loss = ll_stats.avg_statistics( + metric_name="loss", sub_metric_name="total_loss" + ) + loss_surface[coord_a, coord_b] = avg_loss + if self.is_master_node: + print( + "x: {:.2f}, y: {:.2f}, loss: {:.2f}".format( + coords_list[0], coords_list[1], avg_loss + ) + ) + + if self.is_master_node: + lr_list = [0.0] + + for log_writer in self.log_writers: + log_metrics( + lrs=lr_list, + log_writer=log_writer, + train_loss=0.0, + val_loss=avg_loss, + epoch=epoch, + best_metric=self.best_metric, + val_ema_loss=None, + ckpt_metric_name=None, + train_ckpt_metric=None, + val_ckpt_metric=None, + val_ema_ckpt_metric=None, + ) + + gc.collect() + # take a small nap + time.sleep(1) + + if self.is_master_node: + ll_utils.plot_save_graphs( + save_dir=self.save_location, + model_name=model_name, + grid_a=grid_a, + grid_b=grid_b, + loss_surface=loss_surface, + resolution=n_points, + ) + except KeyboardInterrupt as e: + if self.is_master_node: + logger.log("Keyboard interruption. Exiting from early training") + raise e + except Exception as e: + if "out of memory" in str(e): + logger.log("OOM exception occured") + n_gpus = getattr(self.opts, "dev.num_gpus", 1) + for dev_id in range(n_gpus): + mem_summary = torch.cuda.memory_summary( + device=torch.device("cuda:{}".format(dev_id)), abbreviated=True + ) + logger.log("Memory summary for device id: {}".format(dev_id)) + print(mem_summary) + else: + logger.log( + "Exception occurred that interrupted the training. {}".format( + str(e) + ) + ) + print(e) + raise e + finally: + if self.use_distributed: + torch.distributed.destroy_process_group() + + torch.cuda.empty_cache() + + if self.is_master_node: + ll_end_time = time.time() + hours, rem = divmod(ll_end_time - ll_start_time, 3600) + minutes, seconds = divmod(rem, 60) + train_time_str = "{:0>2}:{:0>2}:{:05.2f}".format( + int(hours), int(minutes), seconds + ) + logger.log("Loss landspace evaluation took {}".format(train_time_str)) diff --git a/Adaptive Frequency Filters/engine/utils.py b/Adaptive Frequency Filters/engine/utils.py new file mode 100644 index 0000000..61406cc --- /dev/null +++ b/Adaptive Frequency Filters/engine/utils.py @@ -0,0 +1,172 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from utils import logger +import torch +from torch import Tensor +from typing import Optional, Dict, Union, List, Any +import gc +from torch.cuda.amp import autocast + +from utils.ddp_utils import is_master +from utils.tensor_utils import create_rand_tensor +from utils.common_utils import create_directories + +str_to_torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16} + + +def autocast_fn(enabled: bool, amp_precision: Optional[str] = "float16"): + if enabled: + # If AMP is enabled, ensure that: + # 1. Device is CUDA + # 2. dtype is FLOAT16 or BFLOAT16 + if amp_precision not in str_to_torch_dtype: + logger.error( + "For Mixed-precision training, supported dtypes are {}. Got: {}".format( + list(str_to_torch_dtype.keys()), amp_precision + ) + ) + + if not torch.cuda.is_available(): + logger.error("For mixed-precision training, CUDA device is required.") + + return autocast(enabled=enabled, dtype=str_to_torch_dtype[amp_precision]) + else: + return autocast(enabled=False) + + +def print_summary( + opts, + model, + criteria: Optional = None, + optimizer: Optional = None, + scheduler: Optional = None, +) -> None: + if is_master(opts): + logger.log(logger.color_text("Model")) + print(model) + dev = getattr(opts, "dev.device", torch.device("cpu")) + try: + inp_tensor = create_rand_tensor(opts, device=dev) + + if hasattr(model, "module"): + model.module.profile_model(inp_tensor) + else: + model.profile_model(inp_tensor) + del inp_tensor + except Exception as e: + pass + + if criteria is not None: + # print criteria + logger.log(logger.color_text("Loss function")) + print("{}".format(criteria)) + + if optimizer is not None: + logger.log(logger.color_text("Optimizer")) + print("{}".format(optimizer)) + + if scheduler is not None: + logger.log(logger.color_text("Learning rate scheduler")) + print("{}".format(scheduler)) + + gc.collect() + + +def get_batch_size(x: Union[Tensor, Dict, List]) -> int: + if isinstance(x, Tensor): + return x.shape[0] + elif isinstance(x, Dict) and "image" in x: + return get_batch_size(x["image"]) + elif isinstance(x, List): + return len(x) + else: + raise NotImplementedError(f"Invalid type {type(x)}") + + +def log_metrics( + lrs: Union[List, float], + log_writer, + train_loss: float, + val_loss: float, + epoch: int, + best_metric: float, + val_ema_loss: Optional[float] = None, + ckpt_metric_name: Optional[str] = None, + train_ckpt_metric: Optional[float] = None, + val_ckpt_metric: Optional[float] = None, + val_ema_ckpt_metric: Optional[float] = None, +) -> None: + if not isinstance(lrs, list): + lrs = [lrs] + for g_id, lr_val in enumerate(lrs): + log_writer.add_scalar("LR/Group-{}".format(g_id), round(lr_val, 6), epoch) + + log_writer.add_scalar("Train/Loss", round(train_loss, 2), epoch) + log_writer.add_scalar("Val/Loss", round(val_loss, 2), epoch) + log_writer.add_scalar("Common/Best Metric", round(best_metric, 2), epoch) + if val_ema_loss is not None: + log_writer.add_scalar("Val_EMA/Loss", round(val_ema_loss, 2), epoch) + + # If val checkpoint metric is different from loss, add that too + if ckpt_metric_name is not None and ckpt_metric_name != "loss": + if train_ckpt_metric is not None: + log_writer.add_scalar( + "Train/{}".format(ckpt_metric_name.title()), + round(train_ckpt_metric, 2), + epoch, + ) + if val_ckpt_metric is not None: + log_writer.add_scalar( + "Val/{}".format(ckpt_metric_name.title()), + round(val_ckpt_metric, 2), + epoch, + ) + if val_ema_ckpt_metric is not None: + log_writer.add_scalar( + "Val_EMA/{}".format(ckpt_metric_name.title()), + round(val_ema_ckpt_metric, 2), + epoch, + ) + + +def get_log_writers(opts: Dict[str, Any], save_location: Optional[str]): + is_master_node = is_master(opts) + + log_writers = [] + if not is_master_node: + return log_writers + + tensorboard_logging = getattr(opts, "common.tensorboard_logging", False) + if tensorboard_logging and save_location is not None: + try: + from torch.utils.tensorboard import SummaryWriter + except ImportError as e: + logger.log( + "Unable to import SummaryWriter from torch.utils.tensorboard. Disabling tensorboard logging" + ) + SummaryWriter = None + + if SummaryWriter is not None: + exp_dir = "{}/tb_logs".format(save_location) + create_directories(dir_path=exp_dir, is_master_node=is_master_node) + log_writers.append( + SummaryWriter(log_dir=exp_dir, comment="Training and Validation logs") + ) + + bolt_logging = getattr(opts, "common.bolt_logging", False) + if bolt_logging: + try: + from internal.utils.bolt_logger import BoltLogger + except ModuleNotFoundError: + BoltLogger = None + + if BoltLogger is None: + logger.log("Unable to import bolt. Disabling bolt logging") + else: + log_writers.append(BoltLogger()) + + return log_writers diff --git a/Adaptive Frequency Filters/loss_fn/__init__.py b/Adaptive Frequency Filters/loss_fn/__init__.py new file mode 100644 index 0000000..b906b9c --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/__init__.py @@ -0,0 +1,120 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from .base_criteria import BaseCriteria +import os +import importlib +from utils import logger +import argparse + +LOSS_REGISTRY = {} + + +def register_loss_fn(name): + def register_loss_fn_class(cls): + if name in LOSS_REGISTRY: + raise ValueError( + "Cannot register duplicate loss function ({})".format(name) + ) + + if not issubclass(cls, BaseCriteria): + raise ValueError( + "Criteria ({}: {}) must extend BaseCriteria".format(name, cls.__name__) + ) + + LOSS_REGISTRY[name] = cls + return cls + + return register_loss_fn_class + + +def build_loss_fn(opts): + loss_fn_category = getattr(opts, "loss.category", "classification").lower() + loss_fn = None + if loss_fn_category in LOSS_REGISTRY: + loss_fn = LOSS_REGISTRY[loss_fn_category](opts) + else: + temp_list = list(LOSS_REGISTRY.keys()) + temp_str = "Loss function ({}) not yet supported. \n Supported loss functions are:".format( + loss_fn_category + ) + for i, m_name in enumerate(temp_list): + temp_str += "\n\t {}: {}".format(i, logger.color_text(m_name)) + logger.error(temp_str) + + return loss_fn + + +def general_loss_fn_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="Loss function arguments", description="Loss function arguments" + ) + + group.add_argument( + "--loss.category", + type=str, + default="classification", + help="Loss function category (classification,segmentation)", + ) + group.add_argument( + "--loss.ignore-idx", type=int, default=-1, help="Ignore idx in loss function" + ) + + return parser + + +def neural_aug_loss_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="Arguments related to Neural Aug loss function", + description="Arguments related to Neural Aug loss function", + ) + + group.add_argument( + "--loss.neural-aug.perceptual-metric", + type=str, + default="psnr", + help="Name of the perceptual metric", + ) + + group.add_argument( + "--loss.neural-aug.target-value", + type=float, + default=20.0, + nargs="+", + help="Target value of augmented tensor", + ) + + group.add_argument( + "--loss.neural-aug.curriculum-method", + type=str, + default="linear", + choices=["linear", "cosine"], + help="Use perceptual score for identifying the samples are good or not. ", + ) + return parser + + +def arguments_loss_fn(parser: argparse.ArgumentParser): + parser = general_loss_fn_args(parser=parser) + parser = neural_aug_loss_args(parser=parser) + + # add loss function specific arguments + for k, v in LOSS_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + return parser + + +# automatically import the loss functions +loss_fn_dir = os.path.dirname(__file__) +for file in os.listdir(loss_fn_dir): + path = os.path.join(loss_fn_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + loss_fn_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("loss_fn." + loss_fn_name) diff --git a/Adaptive Frequency Filters/loss_fn/base_criteria.py b/Adaptive Frequency Filters/loss_fn/base_criteria.py new file mode 100644 index 0000000..7cd6d11 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/base_criteria.py @@ -0,0 +1,51 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import nn, Tensor +import argparse +from typing import Any + + +class BaseCriteria(nn.Module): + def __init__(self, *args, **kwargs): + super(BaseCriteria, self).__init__() + self.eps = 1e-7 + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + def forward( + self, input_sample: Any, prediction: Any, target: Any, *args, **kwargs + ) -> Tensor: + raise NotImplementedError + + @staticmethod + def _class_weights(target: Tensor, n_classes: int, norm_val: float = 1.1) -> Tensor: + class_hist: Tensor = torch.histc( + target.float(), bins=n_classes, min=0, max=n_classes - 1 + ) + mask_indices = class_hist == 0 + + # normalize between 0 and 1 by dividing by the sum + norm_hist = torch.div(class_hist, class_hist.sum()) + norm_hist = torch.add(norm_hist, norm_val) + + # compute class weights.. + # samples with more frequency will have less weight and vice-versa + class_wts = torch.div(torch.ones_like(class_hist), torch.log(norm_hist)) + + # mask the classes which do not have samples in the current batch + class_wts[mask_indices] = 0.0 + + return class_wts.to(device=target.device) + + def extra_repr(self) -> str: + return "" + + def __repr__(self): + return "{}({}\n)".format(self.__class__.__name__, self.extra_repr()) diff --git a/Adaptive Frequency Filters/loss_fn/base_neural_aug.py b/Adaptive Frequency Filters/loss_fn/base_neural_aug.py new file mode 100644 index 0000000..6c491f9 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/base_neural_aug.py @@ -0,0 +1,220 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import Tensor +import argparse +import math +from torch.nn import functional as F + +from utils import logger +from utils.ddp_utils import is_master + +from . import BaseCriteria + + +class BaseNeuralAug(BaseCriteria): + __supported_metrics = ["psnr"] + + def __init__(self, opts, *args, **kwargs): + super().__init__(opts, *args, **kwargs) + + perceptual_metric = getattr(opts, "loss.neural_aug.perceptual_metric", "psnr") + is_master_node = is_master(opts) + if perceptual_metric is None and is_master_node: + logger.error( + "Perceptual metric can't be none. " + "Please specify perceptual metric using --loss.neural-aug.perceptual-metric argument" + ) + if not isinstance(perceptual_metric, str) and is_master_node: + logger.error( + "The type of perceptual metric is not string. Got: {}".format( + type(perceptual_metric) + ) + ) + perceptual_metric = perceptual_metric.lower() + target_value = getattr(opts, "loss.neural_aug.target_value", None) + + self.curriculumn_learning = False + self.iteration_based_training = getattr( + opts, "scheduler.is_iteration_based", False + ) + self.target_str = f"{target_value}" + if perceptual_metric == "psnr": + if target_value is None and is_master_node: + logger.error("Target PSNR value can not be None.") + + if isinstance(target_value, (int, float)): + if target_value < 0: + if is_master_node: + logger.error( + "PSNR value should be >= 0 in {}. Got: {}".format( + self.__class__.__name__, target_value + ) + ) + # compute target MSE using below equation + # # PSNR = 20 log10(255) - 10 log10(MSE) + target_mse = 10.0 ** ((20.0 * math.log10(255.0) - target_value) / 10.0) + self.target_value = torch.ones(size=(1,), dtype=torch.float).fill_( + target_mse + ) + self.target_str = f"{target_value}" + elif isinstance(target_value, (list, tuple)) and len(target_value) == 2: + start_target_value = target_value[0] + end_target_value = target_value[1] + + if start_target_value < 0 or end_target_value < 0: + if is_master_node: + logger.error( + "PSNR value should be >= 0 in {}. Got: {}".format( + self.__class__.__name__, target_value + ) + ) + + # compute target MSE using below equation + # # PSNR = 20 log10(255) - 10 log10(MSE) + start_target_mse = 10.0 ** ( + (20.0 * math.log10(255.0) - start_target_value) / 10.0 + ) + end_target_mse = 10.0 ** ( + (20.0 * math.log10(255.0) - end_target_value) / 10.0 + ) + + max_steps = ( + getattr(opts, "scheduler.max_iterations", None) + if self.iteration_based_training + else getattr(opts, "scheduler.max_epochs", None) + ) + + if max_steps is None and is_master_node: + logger.error( + "Please specify {}. Got None.".format( + "--scheduler.max-iterations" + if self.iteration_based_training + else "--scheduler.max-epochs" + ) + ) + + curriculum_method = getattr( + opts, "loss.neural_aug.curriculum_method", None + ) + if curriculum_method in CURRICULUMN_METHOD.keys(): + self.target_value = CURRICULUMN_METHOD[curriculum_method]( + start=start_target_mse, end=end_target_mse, period=max_steps + ) + else: + raise NotImplementedError + + self.curriculumn_learning = True + self.target_str = f"[{start_target_value}, {end_target_value}]" + else: + raise NotImplementedError + + # the maximum possible MSE error is computed as: + # a = torch.ones((3, H, W)) * 255.0 # Max. input value is 255.0 + # b = torch.zeros((3, H, W)) # min. input value is 0.0 + # mse = torch.mean( (a -b) ** 2) + + self.alpha = 100.0 / 65025.0 # 65025 is the maximum mse + else: + if is_master_node: + logger.error( + "Supported perceptual metrics are: {}. Got: {}".format( + self.__supported_metrics, perceptual_metric + ) + ) + self.perceptual_metric = perceptual_metric + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + def _forward_psnr( + self, input_tensor: Tensor, augmented_tensor: Tensor, *args, **kwargs + ) -> Tensor: + squared_err = ((augmented_tensor - input_tensor) * 255.0) ** 2 + # [B, C, H, W] --> [B] + pred_mse = torch.mean(squared_err, dim=[1, 2, 3]) + + # compute L1 loss between target and current MSE + if self.curriculumn_learning: + step = ( + kwargs.get("iterations", 0) + if self.iteration_based_training + else kwargs.get("epoch", 0) + ) + if step >= len(self.target_value): + step = -1 + target_mse = self.target_value[step] + else: + target_mse = self.target_value + + loss_na = F.smooth_l1_loss( + input=pred_mse, + target=target_mse.expand_as(pred_mse).to( + device=pred_mse.device, dtype=pred_mse.dtype + ), + reduction="mean", + ) + + loss_na = loss_na * self.alpha + return loss_na + + def forward_neural_aug( + self, input_tensor: Tensor, augmented_tensor: Tensor, *args, **kwargs + ) -> Tensor: + + if self.perceptual_metric == "psnr": + loss_na = self._forward_psnr( + input_tensor=input_tensor, + augmented_tensor=augmented_tensor, + *args, + **kwargs, + ) + return loss_na + else: + logger.error( + "Supported perceptual metrics are {}. Got: {}".format( + self.__supported_metrics, self.perceptual_metric + ) + ) + + def repr_na(self): + return ( + "\n\ttarget_metric={}" + "\n\ttarget_value={}" + "\n\tcurriculum_learning={}".format( + self.perceptual_metric, + self.target_str, + self.curriculumn_learning, + ) + ) + + def __repr__(self): + return "{}()".format(self.__class__.__name__) + + +def linear_curriculumn(start, end, period): + """This function implements linear curriculumn""" + return torch.linspace(start=start, end=end, steps=period + 1, dtype=torch.float) + + +def cosine_curriculumn(start, end, period): + """This function implements cosine curriculumn""" + + curr = [ + end + 0.5 * (start - end) * (1 + math.cos(math.pi * i / (period + 1))) + for i in range(period + 1) + ] + + curr = torch.tensor(curr, dtype=torch.float) + return curr + + +CURRICULUMN_METHOD = { + "linear": linear_curriculumn, + "cosine": cosine_curriculumn, +} diff --git a/Adaptive Frequency Filters/loss_fn/classification.py b/Adaptive Frequency Filters/loss_fn/classification.py new file mode 100644 index 0000000..9f20156 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/classification.py @@ -0,0 +1,48 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import Tensor +import argparse +from utils import logger + +from . import BaseCriteria, register_loss_fn +from .classification_loss_fns import get_classification_loss, arguments_cls_loss_fn + + +@register_loss_fn("classification") +class ClassificationLoss(BaseCriteria): + def __init__(self, opts, *args, **kwargs): + super().__init__(opts, *args, **kwargs) + + self.criteria = get_classification_loss(opts=opts, *args, **kwargs) + + def forward( + self, input_sample: Tensor, prediction: Tensor, target: Tensor, *args, **kwargs + ) -> Tensor: + return self.criteria( + input_sample=input_sample, + prediction=prediction, + target=target, + *args, + **kwargs + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--loss.classification.name", + type=str, + default="cross_entropy", + help="Loss function name", + ) + parser = arguments_cls_loss_fn(parser) + return parser + + def __repr__(self): + return self.criteria.__repr__() diff --git a/Adaptive Frequency Filters/loss_fn/classification_loss_fns/__init__.py b/Adaptive Frequency Filters/loss_fn/classification_loss_fns/__init__.py new file mode 100644 index 0000000..626a060 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/classification_loss_fns/__init__.py @@ -0,0 +1,82 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import importlib +import os +import argparse + +from utils import logger + +from ..base_criteria import BaseCriteria + +SUPPORTED_CLS_LOSS_FNS = [] +CLS_LOSS_FN_REGISTRY = {} + + +def register_classification_loss_fn(name): + def register_fn(cls): + if name in SUPPORTED_CLS_LOSS_FNS: + raise ValueError( + "Cannot register duplicate classfication loss function ({})".format( + name + ) + ) + + if not issubclass(cls, BaseCriteria): + raise ValueError( + "Loss function ({}: {}) must extend BaseCriteria".format( + name, cls.__name__ + ) + ) + + CLS_LOSS_FN_REGISTRY[name] = cls + SUPPORTED_CLS_LOSS_FNS.append(name) + return cls + + return register_fn + + +def arguments_cls_loss_fn(parser: argparse.ArgumentParser): + # add loss function specific arguments + for k, v in CLS_LOSS_FN_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + return parser + + +def supported_loss_fn_str(loss_fn_name): + supp_str = ( + "Loss function ({}) is not yet supported. \n Supported functions are:".format( + loss_fn_name + ) + ) + for i, fn_name in enumerate(SUPPORTED_CLS_LOSS_FNS): + supp_str += "{} \t".format(fn_name) + logger.error(supp_str) + + +def get_classification_loss(opts, *args, **kwargs): + loss_fn_name = getattr(opts, "loss.classification.name", "cross_entropy") + + if loss_fn_name in SUPPORTED_CLS_LOSS_FNS: + return CLS_LOSS_FN_REGISTRY[loss_fn_name](opts, *args, **kwargs) + else: + supported_loss_fn_str(loss_fn_name) + return None + + +# automatically import different loss functions +loss_fn_dir = os.path.dirname(__file__) +for file in os.listdir(loss_fn_dir): + path = os.path.join(loss_fn_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module( + "loss_fn.classification_loss_fns." + model_name + ) diff --git a/Adaptive Frequency Filters/loss_fn/classification_loss_fns/binary_cross_entropy.py b/Adaptive Frequency Filters/loss_fn/classification_loss_fns/binary_cross_entropy.py new file mode 100644 index 0000000..1f49524 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/classification_loss_fns/binary_cross_entropy.py @@ -0,0 +1,36 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch.nn import functional as F +from torch import Tensor +import argparse + +from . import register_classification_loss_fn +from .. import BaseCriteria + + +@register_classification_loss_fn(name="binary_cross_entropy") +class ClsBinaryCrossEntropy(BaseCriteria): + """Binary CE for classification tasks""" + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts, *args, **kwargs) + + def forward( + self, input_sample: Tensor, prediction: Tensor, target: Tensor, *args, **kwargs + ) -> Tensor: + if target.dim() != prediction.dim(): + target = F.one_hot(target, num_classes=prediction.shape[-1]) + + return F.binary_cross_entropy_with_logits( + input=prediction, + target=target.to(prediction.dtype), + weight=None, + reduction="sum", + ) + + def __repr__(self) -> str: + return "{}()".format(self.__class__.__name__) diff --git a/Adaptive Frequency Filters/loss_fn/classification_loss_fns/cross_entropy.py b/Adaptive Frequency Filters/loss_fn/classification_loss_fns/cross_entropy.py new file mode 100644 index 0000000..7a75b65 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/classification_loss_fns/cross_entropy.py @@ -0,0 +1,75 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch.nn import functional as F +from torch import Tensor +from typing import Dict +import argparse + +from utils import logger + +from . import register_classification_loss_fn +from .. import BaseCriteria + + +@register_classification_loss_fn(name="cross_entropy") +class ClsCrossEntropy(BaseCriteria): + """Cross entropy for classification tasks""" + + def __init__(self, opts, *args, **kwargs): + ignore_idx = getattr(opts, "loss.ignore_idx", -1) + use_class_wts = getattr( + opts, "loss.classification.cross_entropy.class_weights", False + ) + super().__init__(opts, *args, **kwargs) + + self.ignore_idx = ignore_idx + self.use_class_wts = use_class_wts + self.label_smoothing = getattr(opts, "loss.classification.label_smoothing", 0.0) + + def forward( + self, input_sample: Tensor, prediction: Tensor, target: Tensor, *args, **kwargs + ) -> Tensor: + weight = None + if self.use_class_wts and self.training: + n_classes = prediction.shape[1] + weight = self._class_weights(target=target, n_classes=n_classes) + + return F.cross_entropy( + input=prediction, + target=target, + weight=weight, + ignore_index=self.ignore_idx, + label_smoothing=self.label_smoothing + if self.training + else 0.0, # for validation, compute standard CE loss + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--loss.classification.cross-entropy.class-weights", + action="store_true", + help="Use class weights in loss function", + ) + group.add_argument( + "--loss.classification.label-smoothing", + type=float, + default=0.0, + help="Label smoothing value", + ) + return parser + + def __repr__(self): + return "{}(\n\tignore_idx={}\n\tclass_wts={}\n\tlabel_smoothing={}\n)".format( + self.__class__.__name__, + self.ignore_idx, + self.use_class_wts, + self.label_smoothing, + ) diff --git a/Adaptive Frequency Filters/loss_fn/classification_loss_fns/cross_entropy_with_neural_aug.py b/Adaptive Frequency Filters/loss_fn/classification_loss_fns/cross_entropy_with_neural_aug.py new file mode 100644 index 0000000..f7400da --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/classification_loss_fns/cross_entropy_with_neural_aug.py @@ -0,0 +1,94 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- +from torch import Tensor +from typing import Dict +import argparse + +from utils import logger + +from . import register_classification_loss_fn + +from .cross_entropy import ClsCrossEntropy +from ..base_neural_aug import BaseNeuralAug + + +@register_classification_loss_fn(name="cross_entropy_with_na") +class CrossEntropyWithNA(ClsCrossEntropy, BaseNeuralAug): + """Cross entropy with Perceptual loss for classification tasks with neural augmentation""" + + def __init__(self, opts, *args, **kwargs): + ClsCrossEntropy.__init__(self, opts, *args, **kwargs) + BaseNeuralAug.__init__(self, opts, *args, **kwargs) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + def forward( + self, input_sample: Tensor, prediction: Dict, target: Tensor, *args, **kwargs + ) -> Dict[str, Tensor]: + if not isinstance(prediction, Dict): + logger.error( + "Prediction needs to be an instance of Dict and must contain logits and augmented_tensor" + " as keys" + ) + + if not {"augmented_tensor", "logits"}.issubset(prediction.keys()): + logger.error( + "Prediction needs to be an instance of Dict and must contain logits and augmented_tensor" + " as keys. Got keys: {}".format(prediction.keys()) + ) + + augmented_tensor = prediction.get("augmented_tensor", None) + logits = prediction.get("logits", None) + + if augmented_tensor is None: + ce_loss = ClsCrossEntropy.forward( + self, + input_sample=input_sample, + prediction=logits, + target=target, + *args, + **kwargs + ) + return {"total_loss": ce_loss} + + loss_na = self.forward_neural_aug( + input_tensor=input_sample, + augmented_tensor=augmented_tensor, + *args, + **kwargs + ) + + ce_loss = ClsCrossEntropy.forward( + self, + input_sample=augmented_tensor, + prediction=logits, + target=target, + *args, + **kwargs + ) + + return { + "total_loss": loss_na + ce_loss, + "na_loss": loss_na, + "cls_loss": ce_loss, + } + + def __repr__(self): + repr_str = ( + "{}(\n\tignore_idx={}" + "\n\tclass_wts={}" + "\n\tlabel_smoothing={}{}" + "\n)".format( + self.__class__.__name__, + self.ignore_idx, + self.use_class_wts, + self.label_smoothing, + self.repr_na(), + ) + ) + return repr_str diff --git a/Adaptive Frequency Filters/loss_fn/detection.py b/Adaptive Frequency Filters/loss_fn/detection.py new file mode 100644 index 0000000..86a3b85 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/detection.py @@ -0,0 +1,56 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import Tensor +import argparse +from typing import Union + +from . import BaseCriteria, register_loss_fn +from .detection_loss_fns import get_detection_loss, arguments_detection_loss_fn + + +@register_loss_fn("detection") +class DetectionLoss(BaseCriteria): + def __init__(self, opts, *args, **kwargs): + super().__init__(opts, *args, **kwargs) + + self.criteria = get_detection_loss(opts=opts) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--loss.detection.name", + type=str, + default="cross_entropy", + help="Detection loss function name", + ) + + parser = arguments_detection_loss_fn(parser) + return parser + + def forward( + self, + input_sample: Tensor, + prediction: Union[Tensor, Union[Tensor, Tensor]], + target: Tensor, + *args, + **kwargs + ) -> Tensor: + + loss = self.criteria( + input_sample=input_sample, + prediction=prediction, + target=target, + *args, + **kwargs + ) + return loss + + def __repr__(self): + return self.criteria.__repr__() diff --git a/Adaptive Frequency Filters/loss_fn/detection_loss_fns/__init__.py b/Adaptive Frequency Filters/loss_fn/detection_loss_fns/__init__.py new file mode 100644 index 0000000..49abb97 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/detection_loss_fns/__init__.py @@ -0,0 +1,76 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import importlib +import os +import argparse + +from utils import logger + +from ..base_criteria import BaseCriteria + +SUPPORTED_DETECTION_LOSS_FNS = [] +DETECTION_LOSS_FN_REGISTRY = {} + + +def register_detection_loss_fn(name): + def register_fn(cls): + if name in SUPPORTED_DETECTION_LOSS_FNS: + raise ValueError( + "Cannot register duplicate detection loss function ({})".format(name) + ) + + if not issubclass(cls, BaseCriteria): + raise ValueError( + "Loss function ({}: {}) must extend BaseCriteria".format( + name, cls.__name__ + ) + ) + + DETECTION_LOSS_FN_REGISTRY[name] = cls + SUPPORTED_DETECTION_LOSS_FNS.append(name) + return cls + + return register_fn + + +def supported_loss_fn_str(loss_fn_name): + supp_str = "Loss function ({}) is not yet supported. \n Supported functions for detection are:".format( + loss_fn_name + ) + for i, fn_name in enumerate(SUPPORTED_DETECTION_LOSS_FNS): + supp_str += "{} \t".format(fn_name) + logger.error(supp_str) + + +def arguments_detection_loss_fn(parser: argparse.ArgumentParser): + # add loss function specific arguments + for k, v in DETECTION_LOSS_FN_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + return parser + + +def get_detection_loss(opts): + loss_fn_name = getattr(opts, "loss.detection.name", "cross_entropy") + + if loss_fn_name in SUPPORTED_DETECTION_LOSS_FNS: + return DETECTION_LOSS_FN_REGISTRY[loss_fn_name](opts) + else: + supported_loss_fn_str(loss_fn_name) + return None + + +# automatically import different loss functions +loss_fn_dir = os.path.dirname(__file__) +for file in os.listdir(loss_fn_dir): + path = os.path.join(loss_fn_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("loss_fn.detection_loss_fns." + model_name) diff --git a/Adaptive Frequency Filters/loss_fn/detection_loss_fns/mask_rcnn_loss.py b/Adaptive Frequency Filters/loss_fn/detection_loss_fns/mask_rcnn_loss.py new file mode 100644 index 0000000..f6d36f7 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/detection_loss_fns/mask_rcnn_loss.py @@ -0,0 +1,105 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch.nn import functional as F +from torch import Tensor +import argparse +from typing import Tuple, Dict, List + + +from . import register_detection_loss_fn +from .. import BaseCriteria + + +@register_detection_loss_fn(name="mask_rcnn_loss") +class MaskRCNNLoss(BaseCriteria): + """Mask RCNN Loss""" + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--loss.detection.mask-rcnn-loss.classifier-weight", + type=float, + default=1, + help="Weight for classifier.", + ) + group.add_argument( + "--loss.detection.mask-rcnn-loss.box-reg-weight", + type=float, + default=1, + help="Weight for box reg.", + ) + group.add_argument( + "--loss.detection.mask-rcnn-loss.mask-weight", + type=float, + default=1, + help="Weight for mask.", + ) + group.add_argument( + "--loss.detection.mask-rcnn-loss.objectness-weight", + type=float, + default=1, + help="Weight for objectness.", + ) + group.add_argument( + "--loss.detection.mask-rcnn-loss.rpn-box-reg", + type=float, + default=1, + help="Weight for rpn box reg.", + ) + return parser + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts, *args, **kwargs) + + self.classifier_weight = getattr( + opts, "loss.detection.mask_rcnn_loss.classifier_weight" + ) + self.box_reg_weight = getattr( + opts, "loss.detection.mask_rcnn_loss.box_reg_weight" + ) + self.mask_weight = getattr(opts, "loss.detection.mask_rcnn_loss.mask_weight") + self.objectness_weight = getattr( + opts, "loss.detection.mask_rcnn_loss.objectness_weight" + ) + self.rpn_box_reg = getattr(opts, "loss.detection.mask_rcnn_loss.rpn_box_reg") + + def extra_repr(self) -> str: + return ( + f"\n\tclassifier_wt={self.classifier_weight}" + f"\n\tbox_reg_weight={self.box_reg_weight}" + f"\n\tmask_weight={self.mask_weight}" + f"\n\tobjectness_weight={self.objectness_weight}" + f"\n\trpn_box_reg={self.rpn_box_reg}" + ) + + def forward( + self, + input_sample: Dict[str, List], + prediction: Dict[str, Tensor], + *args, + **kwargs, + ) -> Dict[str, Tensor]: + + try: + # Loss is computed inside the Mask RCNN model. Here, we only compute the weighted sum of + # different loss functions. + total_loss = ( + self.classifier_weight * prediction["loss_classifier"] + + self.box_reg_weight * prediction["loss_box_reg"] + + self.mask_weight * prediction["loss_mask"] + + self.objectness_weight * prediction["loss_objectness"] + + self.rpn_box_reg * prediction["loss_rpn_box_reg"] + ) + return {"total_loss": total_loss, **prediction} + except KeyError: + # MaskRCNN doesn't return the loss during validation. + device = input_sample["image"][0].device + return {"total_loss": torch.tensor(0.0, device=device)} diff --git a/Adaptive Frequency Filters/loss_fn/detection_loss_fns/mask_rcnn_loss_with_neural_aug.py b/Adaptive Frequency Filters/loss_fn/detection_loss_fns/mask_rcnn_loss_with_neural_aug.py new file mode 100644 index 0000000..4a0e932 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/detection_loss_fns/mask_rcnn_loss_with_neural_aug.py @@ -0,0 +1,87 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- +import torch +from torch import Tensor +from typing import Dict, List +import argparse + +from utils import logger + +from . import register_detection_loss_fn + +from .mask_rcnn_loss import MaskRCNNLoss +from ..base_neural_aug import BaseNeuralAug + + +@register_detection_loss_fn(name="mask_rcnn_loss_with_na") +class MaskRCNNLossWithNA(MaskRCNNLoss, BaseNeuralAug): + """Mask RCNN loss with neural augmentation""" + + def __init__(self, opts, *args, **kwargs): + MaskRCNNLoss.__init__(self, opts, *args, **kwargs) + BaseNeuralAug.__init__(self, opts, *args, **kwargs) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + def forward( + self, + input_sample: Dict[str, List], + prediction: Dict[str, Tensor], + *args, + **kwargs, + ) -> Dict[str, Tensor]: + if not isinstance(prediction, Dict): + logger.error( + "Prediction needs to be an instance of Dict and must contain logits and augmented_tensor" + " as keys" + ) + + augmented_tensor = prediction.pop("augmented_tensor", None) + + if augmented_tensor is None: + loss = MaskRCNNLoss.forward( + self, input_sample=input_sample, prediction=prediction, *args, **kwargs + ) + return loss + + if not isinstance(input_sample, Dict): + logger.error( + "Input is expected as a Dictionary containing atleast image as a key" + ) + + if not {"image"}.issubset(input_sample.keys()): + logger.error( + "Input is expected as a Dictionary containing atleast image as a key. Got: {}".format( + input_sample.keys() + ) + ) + + input_image_sample = input_sample["image"] + if isinstance(input_image_sample, List): + # if its a list of images, stack them + input_image_sample = torch.stack(input_image_sample, dim=0) + + loss_na = self.forward_neural_aug( + input_tensor=input_image_sample, + augmented_tensor=augmented_tensor, + *args, + **kwargs, + ) + + loss = MaskRCNNLoss.forward( + self, input_sample=input_sample, prediction=prediction, *args, **kwargs + ) + + loss["total_loss"] += loss_na + loss["na_loss"] = loss_na + return loss + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(\n{self.extra_repr() + self.repr_na()}\n)".replace( + "\n\n", "\n" + ) diff --git a/Adaptive Frequency Filters/loss_fn/detection_loss_fns/ssd_multibox_loss.py b/Adaptive Frequency Filters/loss_fn/detection_loss_fns/ssd_multibox_loss.py new file mode 100644 index 0000000..56acb14 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/detection_loss_fns/ssd_multibox_loss.py @@ -0,0 +1,190 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch.nn import functional as F +from torch import Tensor +import argparse +from typing import Tuple, Dict + +from utils.tensor_utils import tensor_to_python_float +from utils import logger +from utils.ddp_utils import is_master +from affnet.misc.third_party.ssd_utils import hard_negative_mining + +from . import register_detection_loss_fn +from .utils import sigmoid_focal_loss +from .. import BaseCriteria + +from torchvision.models.detection.faster_rcnn import FasterRCNN + + +@register_detection_loss_fn(name="ssd_multibox_loss") +class SSDLoss(BaseCriteria): + """SSD Loss""" + + def __init__(self, opts, *args, **kwargs): + super().__init__(opts, *args, **kwargs) + self.unscaled_reg_loss = 1e-7 + self.unscaled_conf_loss = 1e-7 + self.neg_pos_ratio = getattr( + opts, "loss.detection.ssd_multibox_loss.neg_pos_ratio", 3 + ) + self.wt_loc = 1.0 + self.curr_iter = 0 + self.max_iter = getattr( + opts, "loss.detection.ssd_multibox_loss.max_monitor_iter", -1 + ) + self.update_inter = getattr( + opts, "loss.detection.ssd_multibox_loss.update_wt_freq", 200 + ) + self.is_distributed = getattr(opts, "ddp.use_distributed", False) + self.is_master = is_master(opts) + self.label_smoothing = getattr( + opts, "loss.detection.ssd_multibox_loss.label_smoothing", 0.0 + ) + if not (0.0 <= self.label_smoothing < 1.0): + logger.error( + "The value of --loss.detection.ssd-multibox-loss.label-smoothing should be between 0 and 1. " + "Got: {}".format(self.label_smoothing) + ) + + self.reset_unscaled_loss_values() + + def reset_unscaled_loss_values(self): + # initialize with very small float values + self.unscaled_conf_loss = 1e-7 + self.unscaled_reg_loss = 1e-7 + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--loss.detection.ssd-multibox-loss.neg-pos-ratio", + type=int, + default=3, + help="Negative positive ratio in SSD loss", + ) + group.add_argument( + "--loss.detection.ssd-multibox-loss.max-monitor-iter", + type=int, + default=-1, + help="Number of iterations for monitoring location and classification loss.", + ) + group.add_argument( + "--loss.detection.ssd-multibox-loss.update-wt-freq", + type=int, + default=200, + help="Update the weights after N number of iterations", + ) + group.add_argument( + "--loss.detection.ssd-multibox-loss.label-smoothing", + type=float, + default=0.0, + help="Label smoothing for classification labels in SSD", + ) + return parser + + def __repr__(self): + return "{}(\n\tneg_pos_ratio={}\n\tbox_loss=SmoothL1\n\tclass_loss=CrossEntropy\n\twt_loss={}\n)".format( + self.__class__.__name__, + self.neg_pos_ratio, + True if self.max_iter > 0 else False, + ) + + def _forward_detection_loss( + self, prediction: Dict, target: Dict, *args, **kwargs + ) -> Dict[str, Tensor]: + # confidence: (batch_size, num_priors, num_classes) + # predicted_locations :(batch_size, num_priors, 4) + + confidence = prediction["scores"] + predicted_locations = prediction["boxes"] + + gt_labels = target["box_labels"] + gt_locations = target["box_coordinates"] + + num_classes = confidence.shape[-1] + num_coordinates = predicted_locations.shape[-1] + + pos_mask = gt_labels > 0 + predicted_locations = predicted_locations[pos_mask].reshape(-1, num_coordinates) + gt_locations = gt_locations[pos_mask].reshape(-1, num_coordinates) + num_pos = max(1, gt_locations.shape[0]) + smooth_l1_loss = F.smooth_l1_loss( + predicted_locations, gt_locations, reduction="sum" + ) + + with torch.no_grad(): + loss = -F.log_softmax(confidence, dim=2)[:, :, 0] + mask = hard_negative_mining(loss, gt_labels, self.neg_pos_ratio) + + confidence = confidence[mask, :] + label_smoothing = self.label_smoothing if self.training else 0.0 + classification_loss = F.cross_entropy( + input=confidence.reshape(-1, num_classes), + target=gt_labels[mask], + reduction="sum", + label_smoothing=label_smoothing, + ) + + if self.curr_iter <= self.max_iter and self.training: + # classification loss may dominate localization loss or vice-versa + # therefore, to ensure that their contributions are equal towards total loss, we scale regression loss. + # if classification loss contribution is less (or more), then scaling factor will be < 1 ( > 1) + self.unscaled_conf_loss += tensor_to_python_float( + classification_loss, is_distributed=self.is_distributed + ) + self.unscaled_reg_loss += tensor_to_python_float( + smooth_l1_loss, is_distributed=self.is_distributed + ) + + if ( + self.curr_iter + 1 + ) % self.update_inter == 0 or self.curr_iter == self.max_iter: + before_update = round( + tensor_to_python_float( + self.wt_loc, is_distributed=self.is_distributed + ), + 4, + ) + self.wt_loc = self.unscaled_conf_loss / self.unscaled_reg_loss + self.reset_unscaled_loss_values() + + if self.is_master: + after_update = round( + tensor_to_python_float( + self.wt_loc, is_distributed=self.is_distributed + ), + 4, + ) + logger.log( + f"Updating localization loss multiplier from {before_update} to {after_update}" + ) + + self.curr_iter += 1 + + if self.training and self.wt_loc > 0.0: + smooth_l1_loss = smooth_l1_loss * self.wt_loc + + return { + "total_loss": (smooth_l1_loss + classification_loss) / num_pos, + "reg_loss": smooth_l1_loss / num_pos, + "cls_loss": classification_loss / num_pos, + } + + def forward( + self, input_sample: Tensor, prediction: Dict, target: Dict, *args, **kwargs + ) -> Dict[str, Tensor]: + # confidence: (batch_size, num_priors, num_classes) + # predicted_locations :(batch_size, num_priors, 4) + + detection_loss = self._forward_detection_loss( + prediction=prediction, target=target + ) + return detection_loss diff --git a/Adaptive Frequency Filters/loss_fn/detection_loss_fns/utils.py b/Adaptive Frequency Filters/loss_fn/detection_loss_fns/utils.py new file mode 100644 index 0000000..1a48af1 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/detection_loss_fns/utils.py @@ -0,0 +1,58 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch + + +try: + from torchvision.ops import sigmoid_focal_loss +except ModuleNotFoundError: + # copied from torchvision to ensure that code runs with older versions of PyTorch and Torchvision + from torch.nn import functional as F + + def sigmoid_focal_loss( + inputs: torch.Tensor, + targets: torch.Tensor, + alpha: float = 0.25, + gamma: float = 2, + reduction: str = "none", + ): + """ + Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py . + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples or -1 for ignore. Default = 0.25 + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + reduction: 'none' | 'mean' | 'sum' + 'none': No reduction will be applied to the output. + 'mean': The output will be averaged. + 'sum': The output will be summed. + Returns: + Loss tensor with the reduction option applied. + """ + p = torch.sigmoid(inputs) + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = p * targets + (1 - p) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + if reduction == "mean": + loss = loss.mean() + elif reduction == "sum": + loss = loss.sum() + + return loss diff --git a/Adaptive Frequency Filters/loss_fn/distillation.py b/Adaptive Frequency Filters/loss_fn/distillation.py new file mode 100644 index 0000000..4c47dad --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/distillation.py @@ -0,0 +1,53 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import Tensor +import argparse +from utils import logger + +from . import BaseCriteria, register_loss_fn +from .distillation_loss_fns import get_distillation_loss, arguments_distill_loss_fn + + +@register_loss_fn("distillation") +class DistillationLoss(BaseCriteria): + def __init__(self, opts, *args, **kwargs): + loss_fn_name = getattr(opts, "loss.distillation.name", "vanilla") + super().__init__(opts, *args, **kwargs) + self.criteria = get_distillation_loss(opts=opts, *args, **kwargs) + + def forward( + self, input_sample: Tensor, prediction: Tensor, target: Tensor, *args, **kwargs + ) -> Tensor: + return self.criteria( + input_sample=input_sample, + prediction=prediction, + target=target, + *args, + **kwargs + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--loss.distillation.name", + type=str, + default="vanilla", + help="Distillation loss function name", + ) + parser = arguments_distill_loss_fn(parser=parser) + return parser + + def extra_repr(self) -> str: + if hasattr(self.criteria, "extra_repr"): + return self.criteria.extra_repr() + return "" + + def __repr__(self): + return "{}({}\n)".format(self.criteria.__class__.__name__, self.extra_repr()) diff --git a/Adaptive Frequency Filters/loss_fn/distillation_loss_fns/__init__.py b/Adaptive Frequency Filters/loss_fn/distillation_loss_fns/__init__.py new file mode 100644 index 0000000..36392b2 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/distillation_loss_fns/__init__.py @@ -0,0 +1,68 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import argparse +import importlib + +from utils import logger + +SUPPORTED_DISTILL_LOSS_FNS = [] +DISTILL_LOSS_FN_REGISTRY = {} + + +def register_distillation_loss_fn(name): + def register_fn(fn): + if name in SUPPORTED_DISTILL_LOSS_FNS: + raise ValueError( + "Cannot register duplicate distillation loss function ({})".format(name) + ) + SUPPORTED_DISTILL_LOSS_FNS.append(name) + DISTILL_LOSS_FN_REGISTRY[name] = fn + return fn + + return register_fn + + +def arguments_distill_loss_fn(parser: argparse.ArgumentParser): + # add loss function specific arguments + for k, v in DISTILL_LOSS_FN_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + return parser + + +def supported_loss_fn_str(loss_fn_name): + supp_str = ( + "Loss function ({}) is not yet supported. \n Supported functions are:".format( + loss_fn_name + ) + ) + for i, fn_name in enumerate(SUPPORTED_DISTILL_LOSS_FNS): + supp_str += "{} \t".format(fn_name) + logger.error(supp_str) + + +def get_distillation_loss(opts, *args, **kwargs): + loss_fn_name = getattr(opts, "loss.distillation.name", None) + + if loss_fn_name in SUPPORTED_DISTILL_LOSS_FNS: + return DISTILL_LOSS_FN_REGISTRY[loss_fn_name](opts, *args, **kwargs) + else: + supported_loss_fn_str(loss_fn_name) + return None + + +# automatically import different loss functions +loss_fn_dir = os.path.dirname(__file__) +for file in os.listdir(loss_fn_dir): + path = os.path.join(loss_fn_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("loss_fn.distillation_loss_fns." + model_name) diff --git a/Adaptive Frequency Filters/loss_fn/distillation_loss_fns/cls_kl_div_loss.py b/Adaptive Frequency Filters/loss_fn/distillation_loss_fns/cls_kl_div_loss.py new file mode 100644 index 0000000..682dc38 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/distillation_loss_fns/cls_kl_div_loss.py @@ -0,0 +1,159 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os.path +from torch.nn import functional as F +import torch +from torch import nn, Tensor +import argparse +from typing import Dict, Union + +from . import register_distillation_loss_fn +from .. import BaseCriteria + +from .utils import build_cls_teacher_from_opts + + +@register_distillation_loss_fn(name="cls_kl_div_loss") +class ClsKLDivLoss(BaseCriteria): + """ + KL Loss for classification + """ + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts, *args, **kwargs) + + self.teacher = build_cls_teacher_from_opts(opts=opts) + self.temperature = getattr( + opts, "loss.distillation.cls_kl_div_loss.temperature", 1.0 + ) + self.distillation_mode = getattr( + opts, "loss.distillation.cls_kl_div_loss.mode", "soft" + ) + self.topk = getattr(opts, "loss.distillation.cls_kl_div_loss.topk", 1) + self.label_smoothing = getattr( + opts, "loss.distillation.cls_kl_div_loss.label-smoothing", 0.0 + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--loss.distillation.cls-kl-div-loss.temperature", + type=float, + default=1.0, + help="Temperature for KL Div. loss", + ) + group.add_argument( + "--loss.distillation.cls-kl-div-loss.mode", + type=str, + default="soft", + help="Distillation mode", + ) + group.add_argument( + "--loss.distillation.cls-kl-div-loss.topk", + type=int, + default=1, + help="Distill top-k labels from teacher when using hard-labels", + ) + group.add_argument( + "--loss.distillation.cls-kl-div-loss.label-smoothing", + type=float, + default=0.0, + help="Use label smoothing when using hard-labels", + ) + return parser + + def extra_repr(self) -> str: + extra_repr_str = ( + f"\n\ttemperature={self.temperature}" f"\n\tmode={self.distillation_mode}" + ) + if self.distillation_mode.find("hard") > -1: + extra_repr_str += ( + f"\n\ttopk={self.topk}" f"\n\tlabel_smoothing={self.label_smoothing}" + ) + return extra_repr_str + + def _forward_soft_labels( + self, prediction: Tensor, teacher_logits: Tensor + ) -> Tensor: + with torch.no_grad(): + teacher_lprobs = F.log_softmax( + teacher_logits / self.temperature, dim=1 + ).detach() + + student_lprobs = F.log_softmax(prediction / self.temperature, dim=-1) + kl_loss = F.kl_div( + student_lprobs, teacher_lprobs, reduction="batchmean", log_target=True + ) + return kl_loss * (self.temperature**2) + + def _forward_hard_labels( + self, prediction: Tensor, teacher_logits: Tensor + ) -> Tensor: + with torch.no_grad(): + teacher_probs = F.softmax(teacher_logits, dim=-1).detach() + _, teacher_topk_labels = torch.topk( + teacher_probs, k=self.topk, dim=-1, largest=True, sorted=True + ) + + if self.topk > 1: + num_classes = prediction.shape[-1] + teacher_topk_labels = F.one_hot( + teacher_topk_labels, num_classes=num_classes + ) + teacher_topk_labels = teacher_topk_labels.sum(1) + teacher_topk_labels = teacher_topk_labels.to(dtype=prediction.dtype) + + # smooth labels corresponding to multiple classes + smooth_class_p = (1.0 - self.label_smoothing) / self.topk + # distribute the mass over remaining classes + smooth_non_class_p = self.label_smoothing / (num_classes - self.topk) + + teacher_topk_labels = torch.where( + teacher_topk_labels == 1.0, smooth_class_p, smooth_non_class_p + ) + + # scale by number of classes. Otherwise, the contribution is small + loss = ( + F.binary_cross_entropy_with_logits( + input=prediction, target=teacher_topk_labels, reduction="mean" + ) + * num_classes + ) + else: + teacher_topk_labels = teacher_topk_labels.reshape(-1) + loss = F.cross_entropy( + input=prediction, + target=teacher_topk_labels, + reduction="mean", + label_smoothing=self.label_smoothing, + ) + return loss + + def forward( + self, input_sample: Tensor, prediction: Tensor, target: Tensor, *args, **kwargs + ) -> Tensor: + + with torch.no_grad(): + self.teacher.eval() + teacher_logits: Union[Tensor, Dict] = self.teacher(input_sample) + # Dict in case of neural aug + if isinstance(teacher_logits, Dict): + teacher_logits = teacher_logits["logits"] + + if self.distillation_mode == "soft": + return self._forward_soft_labels( + prediction=prediction, teacher_logits=teacher_logits + ) + elif self.distillation_mode == "hard": + return self._forward_hard_labels( + prediction=prediction, teacher_logits=teacher_logits + ) + else: + raise NotImplementedError diff --git a/Adaptive Frequency Filters/loss_fn/distillation_loss_fns/cls_kl_div_loss_neural_aug.py b/Adaptive Frequency Filters/loss_fn/distillation_loss_fns/cls_kl_div_loss_neural_aug.py new file mode 100644 index 0000000..d7a6ab8 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/distillation_loss_fns/cls_kl_div_loss_neural_aug.py @@ -0,0 +1,104 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch.nn import functional as F +from torch import Tensor +from typing import Tuple, Union, Dict +import argparse + +from utils import logger + +from . import register_distillation_loss_fn +from .cls_kl_div_loss import ClsKLDivLoss + +from ..base_neural_aug import BaseNeuralAug + + +@register_distillation_loss_fn(name="cls_kl_div_loss_with_na") +class ClsKLDivLossWithNA(ClsKLDivLoss, BaseNeuralAug): + """ + KLDiv loss with Perceptual loss for distillation + """ + + def __init__(self, opts, *args, **kwargs): + BaseNeuralAug.__init__(self, opts, *args, **kwargs) + ClsKLDivLoss.__init__(self, opts, *args, **kwargs) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + def extra_repr(self) -> str: + return super().extra_repr() + self.repr_na() + + def forward( + self, + input_sample: Tensor, + prediction: Union[Dict, Tensor], + target: Tensor, + *args, + **kwargs + ) -> Dict: + + if isinstance(prediction, Tensor): + kl_loss = super().forward( + input_sample=input_sample, + prediction=prediction, + target=target, + *args, + **kwargs + ) + return {"total_loss": kl_loss} + elif isinstance(prediction, Dict): + if not isinstance(prediction, Dict): + logger.error( + "Prediction needs to be an instance of Dict and must contain logits and augmented_tensor" + " as keys" + ) + + if not {"augmented_tensor", "logits"}.issubset(prediction.keys()): + logger.error( + "Prediction needs to be an instance of Dict and must contain logits and augmented_tensor" + " as keys. Got keys: {}".format(prediction.keys()) + ) + + augmented_tensor = prediction.get("augmented_tensor", None) + logits = prediction.get("logits", None) + + if augmented_tensor is None: + kl_loss = ClsKLDivLoss.forward( + self, + input_sample=input_sample, + prediction=logits, + target=target, + *args, + **kwargs + ) + return {"total_loss": kl_loss} + + kl_loss = ClsKLDivLoss.forward( + self, + input_sample=augmented_tensor, + prediction=logits, + target=target, + *args, + **kwargs + ) + + loss_na = self.forward_neural_aug( + input_tensor=input_sample, + augmented_tensor=augmented_tensor, + *args, + **kwargs + ) + + return { + "total_loss": loss_na + kl_loss, + "na_loss": loss_na, + "kl_loss": kl_loss, + } + else: + raise NotImplementedError diff --git a/Adaptive Frequency Filters/loss_fn/distillation_loss_fns/utils.py b/Adaptive Frequency Filters/loss_fn/distillation_loss_fns/utils.py new file mode 100644 index 0000000..8e163c8 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/distillation_loss_fns/utils.py @@ -0,0 +1,37 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse +from torch import nn + +from affnet.models.classification import build_classification_model +from utils import logger + + +def build_cls_teacher_from_opts(opts) -> nn.Module: + """ + Helper function to build a classification teacher model from options + """ + pretrained_model = getattr(opts, "teacher.model.classification.pretrained", None) + if not pretrained_model: + logger.error( + "For distillation, please specify teacher weights using teacher.model.classification.pretrained" + ) + + opts_dict = vars(opts) + teacher_dict = { + # replace teacher with empty string in "teacher.model.*" to get model.* + key.replace("teacher.", ""): value + for key, value in opts_dict.items() + # filter keys related to teacher + if key.split(".")[0] == "teacher" + } + + # convert to Namespace + teacher_opts = argparse.Namespace(**teacher_dict) + + # build teacher model + return build_classification_model(teacher_opts) diff --git a/Adaptive Frequency Filters/loss_fn/multi_modal_img_text.py b/Adaptive Frequency Filters/loss_fn/multi_modal_img_text.py new file mode 100644 index 0000000..238d2ac --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/multi_modal_img_text.py @@ -0,0 +1,47 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import Tensor +import argparse +from typing import Any +from utils import logger + +from . import BaseCriteria, register_loss_fn +from .multi_modal_img_text_loss_fns import ( + get_multi_modal_img_text_loss, + arguments_multi_modal_img_text_loss_fn, +) + + +@register_loss_fn("multi_modal_image_text") +class MultiModalImageTextLoss(BaseCriteria): + def __init__(self, opts, *args, **kwargs): + super().__init__() + self.criteria = get_multi_modal_img_text_loss(opts=opts, *args, **kwargs) + + def forward( + self, input_sample: Any, prediction: Any, target: Any, *args, **kwargs + ) -> Any: + return self.criteria( + input_sample=input_sample, prediction=prediction, target=target + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--loss.multi-modal-image-text.name", + type=str, + default="clip", + help="Loss function name", + ) + parser = arguments_multi_modal_img_text_loss_fn(parser) + return parser + + def __repr__(self): + return self.criteria.__repr__() diff --git a/Adaptive Frequency Filters/loss_fn/multi_modal_img_text_loss_fns/__init__.py b/Adaptive Frequency Filters/loss_fn/multi_modal_img_text_loss_fns/__init__.py new file mode 100644 index 0000000..a328a8e --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/multi_modal_img_text_loss_fns/__init__.py @@ -0,0 +1,81 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import importlib +import os +import argparse + +from utils import logger + +from ..base_criteria import BaseCriteria + +SUPPORTED_MULTI_MODAL_IMG_TEXT_LOSS_FNS = [] +MULTI_MODAL_IMG_TEXT_LOSS_FN_REGISTRY = {} + + +def register_multi_modal_img_text_loss_fns(name): + def register_fn(cls): + if name in SUPPORTED_MULTI_MODAL_IMG_TEXT_LOSS_FNS: + raise ValueError( + "Cannot register duplicate multi-modal image-text loss function ({})".format( + name + ) + ) + + if not issubclass(cls, BaseCriteria): + raise ValueError( + "Loss function ({}: {}) must extend BaseCriteria".format( + name, cls.__name__ + ) + ) + + MULTI_MODAL_IMG_TEXT_LOSS_FN_REGISTRY[name] = cls + SUPPORTED_MULTI_MODAL_IMG_TEXT_LOSS_FNS.append(name) + return cls + + return register_fn + + +def arguments_multi_modal_img_text_loss_fn(parser: argparse.ArgumentParser): + # add loss function specific arguments + for k, v in MULTI_MODAL_IMG_TEXT_LOSS_FN_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + return parser + + +def supported_loss_fn_str(loss_fn_name): + supp_str = ( + "Loss function ({}) is not yet supported. \n Supported functions are:".format( + loss_fn_name + ) + ) + for i, fn_name in enumerate(SUPPORTED_MULTI_MODAL_IMG_TEXT_LOSS_FNS): + supp_str += "{} \t".format(fn_name) + logger.error(supp_str) + + +def get_multi_modal_img_text_loss(opts, *args, **kwargs): + loss_name = getattr(opts, "loss.multi_modal_image_text.name", None) + + if loss_name in SUPPORTED_MULTI_MODAL_IMG_TEXT_LOSS_FNS: + return MULTI_MODAL_IMG_TEXT_LOSS_FN_REGISTRY[loss_name](opts, *args, **kwargs) + else: + supported_loss_fn_str(loss_name) + + +# automatically import different loss functions +loss_fn_dir = os.path.dirname(__file__) +for file in os.listdir(loss_fn_dir): + path = os.path.join(loss_fn_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + loss_fn_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module( + "loss_fn.multi_modal_img_text_loss_fns." + loss_fn_name + ) diff --git a/Adaptive Frequency Filters/loss_fn/multi_modal_img_text_loss_fns/contrastive_loss_clip.py b/Adaptive Frequency Filters/loss_fn/multi_modal_img_text_loss_fns/contrastive_loss_clip.py new file mode 100644 index 0000000..6ab22f9 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/multi_modal_img_text_loss_fns/contrastive_loss_clip.py @@ -0,0 +1,102 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- +import argparse +from typing import Any, Tuple, Dict + +import torch +from torch import Tensor +from torch.nn import functional as F + +from utils.tensor_utils import gather_all_features +from . import BaseCriteria, register_multi_modal_img_text_loss_fns + + +@register_multi_modal_img_text_loss_fns(name="contrastive_loss_clip") +class ContrastiveLossClip(BaseCriteria): + """CLIP Loss function for multi-modal image-text training""" + + def __init__(self, opts, *args, **kwargs) -> None: + super().__init__(opts, *args, **kwargs) + self.rank = getattr(opts, "ddp.rank", 0) + self.use_distributed = getattr(opts, "ddp.use_distributed", False) + self.device = getattr(opts, "dev.device", torch.device("cpu")) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + def forward( + self, input_sample: Any, prediction: Dict, target: Any, *args, **kwargs + ) -> Dict[str, Tensor]: + + image_features = prediction.pop("image", None) + text_features = prediction.pop("text", None) + + if image_features is None or text_features is None: + # if either image features or text features is None, then loss can't be computed. + # simply return 0.0 + return { + "total_loss": torch.tensor(0.0, dtype=torch.float, device=self.device), + } + + assert image_features is not None + assert text_features is not None + + logit_scale = prediction.pop("logit_scale", 1.0) + + # we need to aggregate + gathered_image_features, gathered_text_features = gather_features( + image_features=image_features, + text_features=text_features, + use_distributed=self.use_distributed, + ) + # compute logits + # [B, d] x [BW x d]^T --> [B, BW] + logits_per_image = logit_scale * ( + image_features @ gathered_text_features.transpose(0, 1) + ) + # [B, d] x [BW, d]^T --> [B, BW] + logits_per_text = logit_scale * ( + text_features @ gathered_image_features.transpose(0, 1) + ) + + # generate labels + num_logits = logits_per_image.shape[0] + contrastive_labels = torch.arange( + num_logits, device=logits_per_image.device, dtype=torch.long + ) + + # shift the labels by rank id + contrastive_labels = contrastive_labels + (num_logits * self.rank) + + text_loss = F.cross_entropy(logits_per_text, contrastive_labels) * 0.5 + image_loss = F.cross_entropy(logits_per_image, contrastive_labels) * 0.5 + total_loss = image_loss + text_loss + return { + "total_loss": total_loss, + "image_loss": image_loss, + "text_loss": text_loss, + "logit_scale": logit_scale, + } + + def __repr__(self) -> str: + return "{}()".format(self.__class__.__name__) + + +def gather_features( + image_features: Tensor, text_features: Tensor, use_distributed: bool +) -> Tuple[Tensor, Tensor]: + """ + Helper function that allows us to gather image and text features from all DDP ranks in a differentiable manner + """ + if use_distributed: + # gather features from all ranks + # [B, d] x W --> [BW, d] where W is the world size + gathered_image_features = gather_all_features(features=image_features, dim=0) + # [B, d] x W --> [BW, d] where W is the world size + gathered_text_features = gather_all_features(features=text_features, dim=0) + return gathered_image_features, gathered_text_features + return image_features, text_features diff --git a/Adaptive Frequency Filters/loss_fn/multi_modal_img_text_loss_fns/contrastive_loss_clip_with_neural_aug.py b/Adaptive Frequency Filters/loss_fn/multi_modal_img_text_loss_fns/contrastive_loss_clip_with_neural_aug.py new file mode 100644 index 0000000..df42705 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/multi_modal_img_text_loss_fns/contrastive_loss_clip_with_neural_aug.py @@ -0,0 +1,67 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from typing import Dict, Any +import argparse + +from . import register_multi_modal_img_text_loss_fns +from .contrastive_loss_clip import ContrastiveLossClip + +from ..base_neural_aug import BaseNeuralAug + + +@register_multi_modal_img_text_loss_fns(name="contrastive_loss_clip_with_na") +class ContrastiveLossClipWithNA(ContrastiveLossClip, BaseNeuralAug): + """CLIP Loss function for multi-modal image-text training with neural augmentation""" + + def __init__(self, opts, *args, **kwargs): + ContrastiveLossClip.__init__(self, opts, *args, **kwargs) + BaseNeuralAug.__init__(self, opts, *args, **kwargs) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + def forward( + self, input_sample: Dict, prediction: Dict, target: Any, *args, **kwargs + ) -> Dict: + + augmented_tensor = prediction.get("augmented_tensor") + if augmented_tensor is None: + return ContrastiveLossClip.forward( + self, + input_sample=input_sample, + prediction=prediction, + target=target, + *args, + **kwargs + ) + elif "augmented_tensor" in prediction and "image" in input_sample: + contrastive_loss = ContrastiveLossClip.forward( + self, + input_sample=input_sample, + prediction=prediction, + target=target, + *args, + **kwargs + ) + + loss_na = self.forward_neural_aug( + input_tensor=input_sample["image"], + augmented_tensor=augmented_tensor, + *args, + **kwargs + ) + contrastive_loss["total_loss"] = ( + contrastive_loss.pop("total_loss") + loss_na + ) + contrastive_loss["na_loss"] = loss_na + return contrastive_loss + else: + raise NotImplementedError + + def __repr__(self): + return "{}({}\n)".format(self.__class__.__name__, self.repr_na()) diff --git a/Adaptive Frequency Filters/loss_fn/segmentation.py b/Adaptive Frequency Filters/loss_fn/segmentation.py new file mode 100644 index 0000000..35f2048 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/segmentation.py @@ -0,0 +1,48 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import Tensor +import argparse +from utils import logger +from typing import Any + +from . import BaseCriteria, register_loss_fn +from .segmentation_loss_fns import get_segmentation_loss, arguments_seg_loss_fn + + +@register_loss_fn("segmentation") +class SegmentationLoss(BaseCriteria): + def __init__(self, opts, *args, **kwargs): + super().__init__(opts, *args, **kwargs) + self.criteria = get_segmentation_loss(opts=opts, *args, **kwargs) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--loss.segmentation.name", + type=str, + default="cross_entropy", + help="Segmentation loss function name", + ) + parser = arguments_seg_loss_fn(parser=parser) + return parser + + def forward( + self, input_sample: Any, prediction: Any, target: Any, *args, **kwargs + ) -> Tensor: + return self.criteria( + input_sample=input_sample, + prediction=prediction, + target=target, + *args, + **kwargs + ) + + def __repr__(self): + return self.criteria.__repr__() diff --git a/Adaptive Frequency Filters/loss_fn/segmentation_loss_fns/__init__.py b/Adaptive Frequency Filters/loss_fn/segmentation_loss_fns/__init__.py new file mode 100644 index 0000000..3289e5b --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/segmentation_loss_fns/__init__.py @@ -0,0 +1,78 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import importlib +import argparse + +from utils import logger + +from ..base_criteria import BaseCriteria + +SUPPORTED_SEG_LOSS_FNS = [] +SEG_LOSS_FN_REGISTRY = {} + + +def register_segmentation_loss_fn(name): + def register_fn(cls): + if name in SUPPORTED_SEG_LOSS_FNS: + raise ValueError( + "Cannot register duplicate segmentation loss function ({})".format(name) + ) + + if not issubclass(cls, BaseCriteria): + raise ValueError( + "Loss function ({}: {}) must extend BaseCriteria".format( + name, cls.__name__ + ) + ) + + SUPPORTED_SEG_LOSS_FNS.append(name) + SEG_LOSS_FN_REGISTRY[name] = cls + return cls + + return register_fn + + +def supported_loss_fn_str(loss_fn_name): + supp_str = ( + "Loss function ({}) is not yet supported. \n Supported functions are:".format( + loss_fn_name + ) + ) + for i, fn_name in enumerate(SUPPORTED_SEG_LOSS_FNS): + supp_str += "{} \t".format(fn_name) + logger.error(supp_str) + + +def get_segmentation_loss(opts, *args, **kwargs): + loss_fn_name = getattr(opts, "loss.segmentation.name", "cross_entropy") + + if loss_fn_name in SUPPORTED_SEG_LOSS_FNS: + return SEG_LOSS_FN_REGISTRY[loss_fn_name](opts, *args, **kwargs) + else: + supported_loss_fn_str(loss_fn_name) + return None + + +def arguments_seg_loss_fn(parser: argparse.ArgumentParser): + # add loss function specific arguments + for k, v in SEG_LOSS_FN_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + return parser + + +# automatically import different loss functions +loss_fn_dir = os.path.dirname(__file__) +for file in os.listdir(loss_fn_dir): + path = os.path.join(loss_fn_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("loss_fn.segmentation_loss_fns." + model_name) diff --git a/Adaptive Frequency Filters/loss_fn/segmentation_loss_fns/cross_entropy.py b/Adaptive Frequency Filters/loss_fn/segmentation_loss_fns/cross_entropy.py new file mode 100644 index 0000000..d3c1048 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/segmentation_loss_fns/cross_entropy.py @@ -0,0 +1,131 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch.nn import functional as F +from torch import Tensor +from typing import Tuple, Union, Optional +import argparse + +from . import register_segmentation_loss_fn +from .. import BaseCriteria + + +@register_segmentation_loss_fn(name="cross_entropy") +class SegCrossEntropy(BaseCriteria): + """Cross entropy loss for the task of semantic segmentation""" + + def __init__(self, opts, *args, **kwargs): + super().__init__(opts, *args, **kwargs) + self.ignore_idx = getattr(opts, "loss.ignore_idx", -1) + self.weighted_loss = getattr( + opts, "loss.segmentation.cross_entropy.class_weights", False + ) + self.aux_wt = getattr(opts, "loss.segmentation.cross_entropy.aux_weight", 0.4) + self.label_smoothing = getattr( + opts, "loss.segmentation.cross_entropy.label_smoothing", 0.0 + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="".format(cls.__name__), description="".format(cls.__name__) + ) + group.add_argument( + "--loss.segmentation.cross-entropy.class-weights", + action="store_true", + help="Use class weights in loss function", + ) + group.add_argument( + "--loss.segmentation.cross-entropy.aux-weight", + type=float, + default=0.4, + help="Weight of auxiliary loss", + ) + group.add_argument( + "--loss.segmentation.cross-entropy.label-smoothing", + type=float, + default=0.0, + help="Label smoothing in CE loss for the task of segmentation", + ) + + return parser + + def _compute_loss( + self, pred_mask: Tensor, target_mask: Tensor, weight: Optional[Tensor] = None + ): + b, c, x_h, x_w = pred_mask.shape + b, y_h, y_w = target_mask.shape + + # use label smoothing only for training + label_smoothing = self.label_smoothing if self.training else 0.0 + + if x_h != y_h or x_w != y_w: + pred_mask = F.interpolate( + pred_mask, size=(y_h, y_w), mode="bilinear", align_corners=True + ) + + loss = F.cross_entropy( + input=pred_mask, + target=target_mask, + weight=weight, + ignore_index=self.ignore_idx, + label_smoothing=label_smoothing, + ) + + return loss + + def forward( + self, + input_sample: Tensor, + prediction: Union[Tensor, Tuple[Tensor, Tensor]], + target: Tensor, + *args, + **kwargs + ) -> Tensor: + aux_out = None + if isinstance(prediction, Tuple) and len(prediction) == 2: + mask, aux_out = prediction + assert isinstance(mask, Tensor) + assert isinstance(aux_out, Tensor) + elif isinstance(prediction, Tensor): + mask = prediction + assert isinstance(mask, Tensor) + else: + raise NotImplementedError( + "For computing loss for segmentation task, we need prediction to be an instance of Tuple or Tensor" + ) + + cls_wts = None + if self.training: + if self.weighted_loss: + n_classes = mask.size(1) # Mask is of shape B x C x H x W + cls_wts = self._class_weights(target=target, n_classes=n_classes) + total_loss = self._compute_loss( + pred_mask=mask, target_mask=target, weight=cls_wts + ) + + if aux_out is not None: + loss_aux = self._compute_loss( + pred_mask=aux_out, target_mask=target, weight=cls_wts + ) + total_loss = total_loss + (self.aux_wt * loss_aux) + return total_loss + else: + return self._compute_loss(pred_mask=mask, target_mask=target, weight=None) + + def __repr__(self): + repr_str = ( + "{}(\n\tweighted_loss={}\n\tignore_idx={}\n\tlabel_smoothing={}".format( + self.__class__.__name__, + self.weighted_loss, + self.ignore_idx, + self.label_smoothing, + ) + ) + + if self.aux_wt > 0: + repr_str += "\n\taux_wt={}".format(self.aux_wt) + return repr_str + "\n)" diff --git a/Adaptive Frequency Filters/loss_fn/segmentation_loss_fns/seg_cross_entropy_with_neural_aug.py b/Adaptive Frequency Filters/loss_fn/segmentation_loss_fns/seg_cross_entropy_with_neural_aug.py new file mode 100644 index 0000000..2524321 --- /dev/null +++ b/Adaptive Frequency Filters/loss_fn/segmentation_loss_fns/seg_cross_entropy_with_neural_aug.py @@ -0,0 +1,108 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch.nn import functional as F +from torch import Tensor +from typing import Tuple, Union, Dict +import argparse + +from utils import logger + +from . import register_segmentation_loss_fn +from .cross_entropy import SegCrossEntropy + +from ..base_neural_aug import BaseNeuralAug + + +@register_segmentation_loss_fn(name="seg_cross_entropy_with_na") +class SegCrossEntropyWithNA(SegCrossEntropy, BaseNeuralAug): + """Cross entropy with Perceptual loss for segmentation tasks with neural augmentation""" + + def __init__(self, opts, *args, **kwargs): + SegCrossEntropy.__init__(self, opts, *args, **kwargs) + BaseNeuralAug.__init__(self, opts, *args, **kwargs) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + return parser + + def forward( + self, + input_sample: Tensor, + prediction: Union[Dict, Tensor, Tuple[Tensor, Tensor]], + target: Tensor, + *args, + **kwargs + ) -> Dict: + + if isinstance(prediction, (Tuple, Tensor)): + seg_loss = super().forward( + input_sample=input_sample, + prediction=prediction, + target=target, + *args, + **kwargs + ) + return {"total_loss": seg_loss} + elif isinstance(prediction, Dict): + if not {"augmented_tensor", "segmentation_output"}.issubset( + prediction.keys() + ): + logger.error( + "Prediction needs to be an instance of Dict and must contain segmentation_output and augmented_tensor" + " as keys. Got keys: {}".format(prediction.keys()) + ) + augmented_tensor = prediction.get("augmented_tensor", None) + segmentation_output = prediction.get("segmentation_output", None) + if augmented_tensor is None: + seg_loss = SegCrossEntropy.forward( + self, + input_sample=input_sample, + prediction=segmentation_output, + target=target, + *args, + **kwargs + ) + return {"total_loss": seg_loss} + + seg_loss = SegCrossEntropy.forward( + self, + input_sample=input_sample, + prediction=segmentation_output, + target=target, + *args, + **kwargs + ) + + loss_na = self.forward_neural_aug( + input_tensor=input_sample, + augmented_tensor=augmented_tensor, + *args, + **kwargs + ) + + return { + "total_loss": loss_na + seg_loss, + "na_loss": loss_na, + "seg_loss": seg_loss, + } + else: + raise NotImplementedError + + def __repr__(self): + repr_str = ( + "{}(\n\tweighted_loss={}\n\tignore_idx={}\n\tlabel_smoothing={}{}".format( + self.__class__.__name__, + self.weighted_loss, + self.ignore_idx, + self.label_smoothing, + self.repr_na(), + ) + ) + + if self.aux_wt > 0: + repr_str += "\n\taux_wt={}".format(self.aux_wt) + return repr_str + "\n)" diff --git a/Adaptive Frequency Filters/loss_landscape/__init__.py b/Adaptive Frequency Filters/loss_landscape/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/loss_landscape/landscape_utils.py b/Adaptive Frequency Filters/loss_landscape/landscape_utils.py new file mode 100644 index 0000000..45e8072 --- /dev/null +++ b/Adaptive Frequency Filters/loss_landscape/landscape_utils.py @@ -0,0 +1,135 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import copy + +import torch +import numpy as np +import matplotlib.pyplot as plt +from matplotlib import cm +from matplotlib import animation +from mpl_toolkits.mplot3d import Axes3D +from typing import Optional, Dict + +# https://github.com/xxxnell/how-do-vits-work + + +def rand_basis(ws: Dict, device: Optional[str] = torch.device("cpu")): + return {k: torch.randn(size=v.shape, device=device) for k, v in ws.items()} + + +def normalize_filter(bs: Dict, ws: Dict): + bs = {k: v.float() for k, v in bs.items()} + ws = {k: v.float() for k, v in ws.items()} + + norm_bs = {} + for k in bs: + ws_norm = torch.norm(ws[k], dim=0, keepdim=True) + bs_norm = torch.norm(bs[k], dim=0, keepdim=True) + norm_bs[k] = ws_norm / (bs_norm + 1e-7) * bs[k] + + return norm_bs + + +def ignore_bn(ws: Dict): + ignored_ws = {} + for k in ws: + if len(ws[k].size()) < 2: + ignored_ws[k] = torch.zeros(size=ws[k].size(), device=ws[k].device) + else: + ignored_ws[k] = ws[k] + return ignored_ws + + +def create_bases( + model: torch.nn.Module, + device: Optional[str] = torch.device("cpu"), + has_module: Optional[bool] = False, +): + weight_state_0 = ( + copy.deepcopy(model.module.state_dict()) + if has_module + else copy.deepcopy(model.state_dict()) + ) + bases = [rand_basis(weight_state_0, device) for _ in range(2)] # Use two bases + bases = [normalize_filter(bs, weight_state_0) for bs in bases] + bases = [ignore_bn(bs) for bs in bases] + + return bases + + +def generate_plots(xx, yy, zz, model_name, results_loc): + zz = np.log(zz) + + plt.figure(figsize=(10, 10)) + plt.contour(xx, yy, zz) + plt.savefig(f"{results_loc}/{model_name}_log_contour.png", dpi=100) + plt.close() + + ## 3D plot + fig, ax = plt.subplots(subplot_kw={"projection": "3d"}) + ax.set_axis_off() + surf = ax.plot_surface(xx, yy, zz, cmap=cm.coolwarm, linewidth=0, antialiased=False) + ax.set_xlim(-1, 1) + ax.set_ylim(-1, 1) + + plt.savefig( + f"{results_loc}/{model_name}_log_surface.png", + dpi=100, + format="png", + bbox_inches="tight", + ) + plt.close() + + fig = plt.figure(figsize=(10, 10)) + ax = Axes3D(fig) + ax.set_axis_off() + + def init(): + ax.plot_surface(xx, yy, zz, cmap=cm.coolwarm, linewidth=0, antialiased=False) + ax.set_xlim(-1, 1) + ax.set_ylim(-1, 1) + return (fig,) + + def animate(i): + ax.view_init(elev=(15 * (i // 15) + i % 15) + 0.0, azim=i) + ax.set_xlim(-1, 1) + ax.set_ylim(-1, 1) + return (fig,) + + anim = animation.FuncAnimation( + fig, animate, init_func=init, frames=100, interval=20, blit=True + ) + + anim.save( + f"{results_loc}/{model_name}_log_surface.gif", fps=15, writer="imagemagick" + ) + + +def plot_save_graphs( + save_dir: str, + model_name: str, + grid_a: np.ndarray, + grid_b: np.ndarray, + loss_surface: np.ndarray, + resolution: int, +): + np.save(f"{save_dir}/{model_name}_xx.npy", grid_a) + np.save(f"{save_dir}/{model_name}_yy.npy", grid_b) + np.save(f"{save_dir}/{model_name}_zz.npy", loss_surface) + + plt.figure(figsize=(10, 10)) + plt.contour(grid_a, grid_b, loss_surface) + plt.savefig(f"{save_dir}/{model_name}_contour_res_{resolution}.png", dpi=100) + plt.close() + + generate_plots( + xx=grid_a, + yy=grid_b, + zz=loss_surface, + model_name=model_name, + results_loc=save_dir, + ) diff --git a/Adaptive Frequency Filters/main_eval.py b/Adaptive Frequency Filters/main_eval.py new file mode 100644 index 0000000..0e4f851 --- /dev/null +++ b/Adaptive Frequency Filters/main_eval.py @@ -0,0 +1,161 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +import multiprocessing + + +from affnet import get_model +from data import create_eval_loader +from engine import Evaluator + +from options.opts import get_eval_arguments +from utils import logger +from utils.common_utils import device_setup, create_directories +from utils.ddp_utils import is_master, distributed_init + + +def main(opts, **kwargs): + num_gpus = getattr(opts, "dev.num_gpus", 0) # defaults are for CPU + dev_id = getattr(opts, "dev.device_id", torch.device("cpu")) + device = getattr(opts, "dev.device", torch.device("cpu")) + is_distributed = getattr(opts, "ddp.use_distributed", False) + + # set-up data loaders + val_loader = create_eval_loader(opts) + + # set-up the model + model = get_model(opts) + + # memory format + memory_format = ( + torch.channels_last + if getattr(opts, "common.channels_last", False) + else torch.contiguous_format + ) + + is_master_node = is_master(opts) + if num_gpus <= 1: + model = model.to(device=device, memory_format=memory_format) + elif is_distributed: + model = model.to(device=device, memory_format=memory_format) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[dev_id], output_device=dev_id + ) + if is_master_node: + logger.log("Using DistributedDataParallel for evaluation") + else: + model = model.to(memory_format=memory_format) + model = torch.nn.DataParallel(model) + model = model.to(device=device) + if is_master_node: + logger.log("Using DataParallel for evaluation") + + eval_engine = Evaluator(opts=opts, model=model, eval_loader=val_loader) + eval_engine.run() + + +def distributed_worker(i, main, opts, kwargs): + setattr(opts, "dev.device_id", i) + torch.cuda.set_device(i) + setattr(opts, "dev.device", torch.device(f"cuda:{i}")) + + ddp_rank = getattr(opts, "ddp.rank", None) + if ddp_rank is None: # torch.multiprocessing.spawn + ddp_rank = kwargs.get("start_rank", 0) + i + setattr(opts, "ddp.rank", ddp_rank) + + node_rank = distributed_init(opts) + setattr(opts, "ddp.rank", node_rank) + main(opts, **kwargs) + + +def main_worker(**kwargs): + opts = get_eval_arguments() + print(opts) + # device set-up + opts = device_setup(opts) + + node_rank = getattr(opts, "ddp.rank", 0) + if node_rank < 0: + logger.error("--rank should be >=0. Got {}".format(node_rank)) + + is_master_node = is_master(opts) + + # create the directory for saving results + save_dir = getattr(opts, "common.results_loc", "results") + run_label = getattr(opts, "common.run_label", "run_1") + exp_dir = "{}/{}".format(save_dir, run_label) + setattr(opts, "common.exp_loc", exp_dir) + create_directories(dir_path=exp_dir, is_master_node=is_master_node) + + world_size = getattr(opts, "ddp.world_size", 1) + num_gpus = getattr(opts, "dev.num_gpus", 1) + use_distributed = getattr(opts, "ddp.enable", False) + if num_gpus <= 1: + use_distributed = False + setattr(opts, "ddp.use_distributed", use_distributed) + + # No of data workers = no of CPUs (if not specified or -1) + n_cpus = multiprocessing.cpu_count() + dataset_workers = getattr(opts, "dataset.workers", -1) + + if use_distributed: + if world_size == -1: + logger.log( + "Setting --ddp.world-size the same as the number of available gpus" + ) + world_size = num_gpus + setattr(opts, "ddp.world_size", world_size) + elif world_size != num_gpus: + logger.log( + "--ddp.world-size does not match num. of available GPUs. Got {} !={}".format( + world_size, num_gpus + ) + ) + logger.log("Setting --ddp.world-size={}".format(num_gpus)) + world_size = num_gpus + setattr(opts, "ddp.world_size", world_size) + + if dataset_workers == -1 or dataset_workers is None: + setattr(opts, "dataset.workers", n_cpus // world_size) + + start_rank = getattr(opts, "ddp.rank", 0) + setattr(opts, "ddp.rank", None) + kwargs["start_rank"] = start_rank + torch.multiprocessing.spawn( + fn=distributed_worker, + args=(main, opts, kwargs), + nprocs=num_gpus, + ) + else: + if dataset_workers == -1: + setattr(opts, "dataset.workers", n_cpus) + + # adjust the batch size + train_bsize = getattr(opts, "dataset.train_batch_size0", 32) * max(1, num_gpus) + val_bsize = getattr(opts, "dataset.val_batch_size0", 32) * max(1, num_gpus) + setattr(opts, "dataset.train_batch_size0", train_bsize) + setattr(opts, "dataset.val_batch_size0", val_bsize) + setattr(opts, "dev.device_id", None) + main(opts=opts, **kwargs) + + +# for segmentation and detection, we follow a different evaluation pipeline that allows to save the results too +def main_worker_segmentation(**kwargs): + from engine.eval_segmentation import main_segmentation_evaluation + + main_segmentation_evaluation(**kwargs) + + +def main_worker_detection(**kwargs): + from engine.eval_detection import main_detection_evaluation + + main_detection_evaluation(**kwargs) + + +if __name__ == "__main__": + main_worker() \ No newline at end of file diff --git a/Adaptive Frequency Filters/main_train.py b/Adaptive Frequency Filters/main_train.py new file mode 100644 index 0000000..1f5f691 --- /dev/null +++ b/Adaptive Frequency Filters/main_train.py @@ -0,0 +1,302 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import multiprocessing +import torch +import math +from torch.cuda.amp import GradScaler +from torch.distributed.elastic.multiprocessing import errors + +from utils import logger +from options.opts import get_training_arguments +from utils.common_utils import device_setup, create_directories +from utils.ddp_utils import is_master, distributed_init +from affnet import get_model, EMA +from loss_fn import build_loss_fn +from optim import build_optimizer +from optim.scheduler import build_scheduler +from data import create_train_val_loader +from utils.checkpoint_utils import load_checkpoint, load_model_state +from engine import Trainer +from common import ( + DEFAULT_EPOCHS, + DEFAULT_ITERATIONS, + DEFAULT_MAX_ITERATIONS, + DEFAULT_MAX_EPOCHS, +) +# from torchstat import stat +# from torch.profiler import profile, record_function, ProfilerActivity + +try: + import wandb + has_wandb = True +except ImportError: + has_wandb = False + + +@errors.record +def main(opts, **kwargs): + if is_master(opts): + print("init wandb") + opts.log_wandb = opts.log_wandb and has_wandb + if opts.log_wandb: + wandb.init(project='mobilevit', + config=opts, + name=opts.experiment, + id=opts.experiment, + tags=[ + 'backbone', + ], + settings=wandb.Settings(start_method="fork"), + resume='allow', + ) + num_gpus = getattr(opts, "dev.num_gpus", 0) # defaults are for CPU + dev_id = getattr(opts, "dev.device_id", torch.device("cpu")) + device = getattr(opts, "dev.device", torch.device("cpu")) + is_distributed = getattr(opts, "ddp.use_distributed", False) + + is_master_node = is_master(opts) + + # set-up data loaders + train_loader, val_loader, train_sampler = create_train_val_loader(opts) + + # compute max iterations based on max epochs + # Useful in doing polynomial decay + is_iteration_based = getattr(opts, "scheduler.is_iteration_based", False) + if is_iteration_based: + max_iter = getattr(opts, "scheduler.max_iterations", DEFAULT_ITERATIONS) + if max_iter is None or max_iter <= 0: + logger.log("Setting max. iterations to {}".format(DEFAULT_ITERATIONS)) + setattr(opts, "scheduler.max_iterations", DEFAULT_ITERATIONS) + max_iter = DEFAULT_ITERATIONS + setattr(opts, "scheduler.max_epochs", DEFAULT_MAX_EPOCHS) + if is_master_node: + logger.log("Max. iteration for training: {}".format(max_iter)) + else: + max_epochs = getattr(opts, "scheduler.max_epochs", DEFAULT_EPOCHS) + if max_epochs is None or max_epochs <= 0: + logger.log("Setting max. epochs to {}".format(DEFAULT_EPOCHS)) + setattr(opts, "scheduler.max_epochs", DEFAULT_EPOCHS) + setattr(opts, "scheduler.max_iterations", DEFAULT_MAX_ITERATIONS) + max_epochs = getattr(opts, "scheduler.max_epochs", DEFAULT_EPOCHS) + if is_master_node: + logger.log("Max. epochs for training: {}".format(max_epochs)) + # set-up the model + model = get_model(opts) + + # # add by huangzp + # stat(model, (3, 224, 224)) + # + # # torch profiler + # with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof: + # with record_function("model_inference"): + # model(input) + + # memory format + memory_format = ( + torch.channels_last + if getattr(opts, "common.channels_last", False) + else torch.contiguous_format + ) + + if num_gpus == 0: + logger.warning( + "No GPUs are available, so training on CPU. Consider training on GPU for faster training" + ) + model = model.to(device=device, memory_format=memory_format) + elif num_gpus == 1: + model = model.to(device=device, memory_format=memory_format) + elif is_distributed: + model = model.to(device=device, memory_format=memory_format) + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[dev_id], + output_device=dev_id, + find_unused_parameters=getattr(opts, "ddp.find_unused_params", False), + ) + if is_master_node: + logger.log("Using DistributedDataParallel for training") + else: + model = model.to(memory_format=memory_format) + model = torch.nn.DataParallel(model) + model = model.to(device=device) + if is_master_node: + logger.log("Using DataParallel for training") + + # setup criteria + criteria = build_loss_fn(opts) + criteria = criteria.to(device=device) + + # create the optimizer + optimizer = build_optimizer(model, opts=opts) + + # create the gradient scalar + gradient_scalar = GradScaler(enabled=getattr(opts, "common.mixed_precision", False)) + + # LR scheduler + scheduler = build_scheduler(opts=opts) + + model_ema = None + use_ema = getattr(opts, "ema.enable", False) + + if use_ema: + ema_momentum = getattr(opts, "ema.momentum", 0.0001) + model_ema = EMA(model=model, ema_momentum=ema_momentum, device=device) + if is_master_node: + logger.log("Using EMA") + + best_metric = ( + 0.0 if getattr(opts, "stats.checkpoint_metric_max", False) else math.inf + ) + + start_epoch = 0 + start_iteration = 0 + resume_loc = getattr(opts, "common.resume", None) + finetune_loc = getattr(opts, "common.finetune_imagenet1k", None) + auto_resume = getattr(opts, "common.auto_resume", False) + if resume_loc is not None or auto_resume: + ( + model, + optimizer, + gradient_scalar, + start_epoch, + start_iteration, + best_metric, + model_ema, + ) = load_checkpoint( + opts=opts, + model=model, + optimizer=optimizer, + model_ema=model_ema, + gradient_scalar=gradient_scalar, + ) + elif finetune_loc is not None: + model, model_ema = load_model_state(opts=opts, model=model, model_ema=model_ema) + if is_master_node: + logger.log("Finetuning model from checkpoint {}".format(finetune_loc)) + + training_engine = Trainer( + opts=opts, + model=model, + validation_loader=val_loader, + training_loader=train_loader, + optimizer=optimizer, + criterion=criteria, + scheduler=scheduler, + start_epoch=start_epoch, + start_iteration=start_iteration, + best_metric=best_metric, + model_ema=model_ema, + gradient_scalar=gradient_scalar, + ) + + training_engine.run(train_sampler=train_sampler) + + +def distributed_worker(i, main, opts, kwargs): + setattr(opts, "dev.device_id", i) + torch.cuda.set_device(i) + setattr(opts, "dev.device", torch.device(f"cuda:{i}")) + + ddp_rank = getattr(opts, "ddp.rank", None) + if ddp_rank is None: # torch.multiprocessing.spawn + ddp_rank = kwargs.get("start_rank", 0) + i + setattr(opts, "ddp.rank", ddp_rank) + + node_rank = distributed_init(opts) + setattr(opts, "ddp.rank", node_rank) + main(opts, **kwargs) + + +def main_worker(**kwargs): + opts = get_training_arguments() + print(opts) + # device set-up + opts = device_setup(opts) + + node_rank = getattr(opts, "ddp.rank", 0) + if node_rank < 0: + logger.error("--rank should be >=0. Got {}".format(node_rank)) + + is_master_node = is_master(opts) + + # create the directory for saving results + save_dir = getattr(opts, "common.results_loc", "results") + run_label = getattr(opts, "common.run_label", "run_1") + exp_dir = "{}/{}".format(save_dir, run_label) + setattr(opts, "common.exp_loc", exp_dir) + create_directories(dir_path=exp_dir, is_master_node=is_master_node) + + num_gpus = getattr(opts, "dev.num_gpus", 1) + world_size = getattr(opts, "ddp.world_size", -1) + use_distributed = not getattr(opts, "ddp.disable", False) + if num_gpus <= 1: + use_distributed = False + setattr(opts, "ddp.use_distributed", use_distributed) + + # No of data workers = no of CPUs (if not specified or -1) + n_cpus = multiprocessing.cpu_count() + dataset_workers = getattr(opts, "dataset.workers", -1) + + norm_name = getattr(opts, "model.normalization.name", "batch_norm") + ddp_spawn = not getattr(opts, "ddp.no_spawn", False) + if use_distributed and ddp_spawn and torch.cuda.is_available(): + # get device id + dev_id = getattr(opts, "ddp.device_id", None) + setattr(opts, "dev.device_id", dev_id) + + if world_size == -1: + logger.log( + "Setting --ddp.world-size the same as the number of available gpus" + ) + world_size = num_gpus + setattr(opts, "ddp.world_size", world_size) + + if dataset_workers == -1 or dataset_workers is None: + setattr(opts, "dataset.workers", n_cpus // num_gpus) + + start_rank = getattr(opts, "ddp.rank", 0) + setattr(opts, "ddp.rank", None) + kwargs["start_rank"] = start_rank + setattr(opts, "ddp.start_rank", start_rank) + + # add by huangzp: fix the distribution error of rr1 + def _find_free_port(): + import socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(("", 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. + return port + port = _find_free_port() + setattr(opts, "ddp.dist_port", port) + + torch.multiprocessing.spawn( + fn=distributed_worker, + args=(main, opts, kwargs), + nprocs=num_gpus, + ) + else: + if dataset_workers == -1: + setattr(opts, "dataset.workers", n_cpus) + + if norm_name in ["sync_batch_norm", "sbn"]: + setattr(opts, "model.normalization.name", "batch_norm") + + # adjust the batch size + train_bsize = getattr(opts, "dataset.train_batch_size0", 32) * max(1, num_gpus) + val_bsize = getattr(opts, "dataset.val_batch_size0", 32) * max(1, num_gpus) + setattr(opts, "dataset.train_batch_size0", train_bsize) + setattr(opts, "dataset.val_batch_size0", val_bsize) + setattr(opts, "dev.device_id", None) + main(opts=opts, **kwargs) + + +if __name__ == "__main__": + # + main_worker() diff --git a/Adaptive Frequency Filters/metrics/__init__.py b/Adaptive Frequency Filters/metrics/__init__.py new file mode 100644 index 0000000..87b6c73 --- /dev/null +++ b/Adaptive Frequency Filters/metrics/__init__.py @@ -0,0 +1,74 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import importlib +import argparse + +SUPPORTED_STATS = ["loss", "grad_norm"] + + +def register_stats_fn(name): + def register_fn(fn): + if name in SUPPORTED_STATS: + raise ValueError("Cannot register duplicate state ({})".format(name)) + SUPPORTED_STATS.append(name) + return fn + + return register_fn + + +def arguments_stats(parser: argparse.ArgumentParser): + group = parser.add_argument_group(title="Statistics", description="Statistics") + group.add_argument( + "--stats.val", type=str, default=["loss"], nargs="+", help="Name of statistics" + ) + group.add_argument( + "--stats.train", + type=str, + default=["loss"], + nargs="+", + help="Name of statistics", + ) + group.add_argument( + "--stats.checkpoint-metric", + type=str, + default="loss", + help="Metric to use for saving checkpoints", + ) + group.add_argument( + "--stats.checkpoint-metric-max", + action="store_true", + default=False, + help="Maximize checkpoint metric", + ) + group.add_argument( + "--stats.coco-map.iou_types", + type=str, + default=["bbox"], + nargs="+", + choices=("bbox", "segm"), + help="Types of IOU to compute for MSCoco.", + ) + + return parser + + +# automatically import different metrics +metrics_dir = os.path.dirname(__file__) +for file in os.listdir(metrics_dir): + path = os.path.join(metrics_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("metrics." + model_name) + + +from metrics.stats import Statistics +from metrics.metric_monitor import metric_monitor diff --git a/Adaptive Frequency Filters/metrics/coco_map.py b/Adaptive Frequency Filters/metrics/coco_map.py new file mode 100644 index 0000000..0c36860 --- /dev/null +++ b/Adaptive Frequency Filters/metrics/coco_map.py @@ -0,0 +1,245 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- +import numpy as np +import torch +from torch.nn import functional as F +from typing import Optional, Dict, List +import io +import os +from pycocotools.cocoeval import COCOeval +from pycocotools.coco import COCO +from pycocotools import mask as maskUtils +from contextlib import redirect_stdout + +from affnet.models.detection.base_detection import DetectionPredTuple +from utils.tensor_utils import all_gather_list +from utils import logger +from utils.ddp_utils import is_master + +from . import register_stats_fn + + +@register_stats_fn(name="coco_map") +class COCOEvaluator(object): + def __init__( + self, + opts, + split: Optional[str] = "val", + year: Optional[int] = 2017, + use_distributed: Optional[bool] = False, + *args, + **kwargs + ): + # disable printing on console, so that pycocotools print statements are not printed on console + logger.disable_printing() + bkrnd_id = ( + 0 if getattr(opts, "dataset.detection.no_background_id", False) else 1 + ) + + iou_types = getattr(opts, "stats.coco_map.iou_types", ["bbox"]) + + root = getattr(opts, "dataset.root_val", None) + ann_file = os.path.join( + root, "annotations/instances_{}{}.json".format(split, year) + ) + coco_gt = COCO(ann_file) + + coco_categories = sorted(coco_gt.getCatIds()) + self.coco_id_to_contiguous_id = { + coco_id: i + bkrnd_id for i, coco_id in enumerate(coco_categories) + } + self.contiguous_id_to_coco_id = { + v: k for k, v in self.coco_id_to_contiguous_id.items() + } + + self.coco_gt = coco_gt + self.iou_types = iou_types + self.use_distributed = use_distributed + self.is_master_node = is_master(opts) + + self.coco_results = {iou_type: [] for iou_type in iou_types} + + # enable printing, to enable affnet log printing + logger.enable_printing() + + def prepare_predictions(self, predictions: Dict, targets: List): + if not ( + isinstance(predictions, Dict) + and ({"detections"} <= set(list(predictions.keys()))) + ): + logger.error( + "For coco evaluation during training, the output from the model should be a dictionary " + "and should contain the results in a key called detections" + ) + + detections = predictions["detections"] + + if isinstance(targets, list): + image_ids = torch.tensor( + [t["image_id"] for t in targets], dtype=torch.int64 + ) + image_widths = torch.tensor( + [t["image_width"] for t in targets], dtype=torch.int64 + ) + image_heights = torch.tensor( + [t["image_height"] for t in targets], dtype=torch.int64 + ) + else: + image_ids = targets["image_id"] + image_widths = targets["image_width"] + image_heights = targets["image_height"] + + if isinstance(detections, DetectionPredTuple): + detections = [detections] + + if not ( + isinstance(detections, List) + and isinstance(detections[0], DetectionPredTuple) + ): + logger.error( + "For coco evaluation during training, the results should be stored as a List of DetectionPredTuple" + ) + + self.prepare_cache_results( + detection_results=detections, + image_ids=image_ids, + image_widths=image_widths, + image_heights=image_heights, + ) + + def prepare_cache_results( + self, + detection_results: List[DetectionPredTuple], + image_ids, + image_widths, + image_heights, + ) -> None: + batch_results = {k: [] for k in self.coco_results.keys()} + for detection_result, img_id, img_w, img_h in zip( + detection_results, image_ids, image_widths, image_heights + ): + label = detection_result.labels + + if label.numel() == 0: + # no detections + continue + box = detection_result.boxes + score = detection_result.scores + + img_id, img_w, img_h = img_id.item(), img_w.item(), img_h.item() + + box[..., 0::2] = torch.clip(box[..., 0::2] * img_w, min=0, max=img_w) + box[..., 1::2] = torch.clip(box[..., 1::2] * img_h, min=0, max=img_h) + + # convert box from xyxy to xywh format + box[..., 2] = box[..., 2] - box[..., 0] + box[..., 3] = box[..., 3] - box[..., 1] + + box = box.cpu().numpy() + label = label.cpu().numpy() + score = score.cpu().numpy() + + if "bbox" in batch_results: + batch_results["bbox"].extend( + [ + { + "image_id": img_id, + "category_id": self.contiguous_id_to_coco_id[ + label[bbox_id] + ], + "bbox": box[bbox_id].tolist(), + "score": score[bbox_id], + } + for bbox_id in range(box.shape[0]) + if label[bbox_id] > 0 + ] + ) + + masks = detection_result.masks + if masks is not None and "segm" in batch_results: + # masks are [N, H, W]. For interpolation, convert them to [1, N, H, W] and then back to [N, H, W] + masks = F.interpolate( + masks.unsqueeze(0), size=(img_h, img_w), mode="bilinear", align_corners=True + ).squeeze(0) + masks = masks > 0.5 + + masks = masks.cpu().numpy() + # predicted masks are in [N, H, W] format + rles = [ + maskUtils.encode( + np.array(mask[:, :, np.newaxis], dtype=np.uint8, order="F") + )[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + + batch_results["segm"].extend( + [ + { + "image_id": img_id, + "category_id": self.contiguous_id_to_coco_id[label[seg_id]], + "segmentation": rle, + "score": score[seg_id], + } + for seg_id, rle in enumerate(rles) + if label[seg_id] > 0 + ] + ) + + for k in batch_results.keys(): + self.coco_results[k].extend(batch_results[k]) + + def gather_coco_results(self) -> None: + # synchronize results across different devices + for iou_type, coco_results in self.coco_results.items(): + # agg_coco_results as List[List]. + # The outer list is for processes and inner list is for coco_results in the process + if self.use_distributed: + agg_coco_results = all_gather_list(coco_results) + + merged_coco_results = [] + # filter the duplicates + for ( + p_coco_results + ) in agg_coco_results: # retrieve results from each process + merged_coco_results.extend(p_coco_results) + else: + merged_coco_results = coco_results + + self.coco_results[iou_type] = merged_coco_results + + def summarize_coco_results(self) -> Dict: + stats_map = dict() + for iou_type, coco_results in self.coco_results.items(): + if len(coco_results) < 1: + # during initial epochs, we may not have any sample results, so we can skip this part + map_val = 0.0 + else: + try: + logger.disable_printing() + + with redirect_stdout(io.StringIO()): + coco_dt = COCO.loadRes(self.coco_gt, coco_results) + + coco_eval = COCOeval( + cocoGt=self.coco_gt, cocoDt=coco_dt, iouType=iou_type + ) + coco_eval.evaluate() + coco_eval.accumulate() + + if self.is_master_node: + logger.enable_printing() + + logger.log("Results for IoU Metric: {}".format(iou_type)) + coco_eval.summarize() + map_val = coco_eval.stats[0].item() + except Exception as e: + map_val = 0.0 + stats_map[iou_type] = map_val * 100 + + logger.enable_printing() + return stats_map diff --git a/Adaptive Frequency Filters/metrics/confusion_mat.py b/Adaptive Frequency Filters/metrics/confusion_mat.py new file mode 100644 index 0000000..18705c1 --- /dev/null +++ b/Adaptive Frequency Filters/metrics/confusion_mat.py @@ -0,0 +1,43 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch + + +class ConfusionMatrix(object): + """ + Computes the confusion matrix and is based on `FCN `_ + """ + + def __init__(self): + self.confusion_mat = None + + def update(self, ground_truth, prediction, n_classes): + if self.confusion_mat is None: + self.confusion_mat = torch.zeros( + (n_classes, n_classes), dtype=torch.int64, device=ground_truth.device + ) + with torch.no_grad(): + k = (ground_truth >= 0) & (ground_truth < n_classes) + inds = n_classes * ground_truth[k].to(torch.int64) + prediction[k] + self.confusion_mat += torch.bincount( + inds, minlength=n_classes**2 + ).reshape(n_classes, n_classes) + + def reset(self): + if self.confusion_mat is not None: + self.confusion_mat.zero_() + + def compute(self): + if self.confusion_mat is None: + print("Confusion matrix is None. Check code") + return None + h = self.confusion_mat.float() + acc_global = torch.diag(h).sum() / h.sum() + diag_h = torch.diag(h) + acc = diag_h / h.sum(1) + iu = diag_h / (h.sum(1) + h.sum(0) - diag_h) + return acc_global, acc, iu diff --git a/Adaptive Frequency Filters/metrics/intersection_over_union.py b/Adaptive Frequency Filters/metrics/intersection_over_union.py new file mode 100644 index 0000000..69ac534 --- /dev/null +++ b/Adaptive Frequency Filters/metrics/intersection_over_union.py @@ -0,0 +1,50 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import Tensor +from typing import Optional, Tuple, Union + +from . import register_stats_fn + + +@register_stats_fn(name="iou") +def compute_miou_batch( + prediction: Union[Tuple[Tensor, Tensor], Tensor], + target: Tensor, + epsilon: Optional[float] = 1e-7, +): + if isinstance(prediction, Tuple) and len(prediction) == 2: + mask = prediction[0] + assert isinstance(mask, Tensor) + elif isinstance(prediction, Tensor): + mask = prediction + assert isinstance(mask, Tensor) + else: + raise NotImplementedError( + "For computing loss for segmentation task, we need prediction to be an instance of Tuple or Tensor" + ) + + num_classes = mask.shape[1] + pred_mask = torch.max(mask, dim=1)[1] + assert ( + pred_mask.dim() == 3 + ), "Predicted mask tensor should be 3-dimensional (B x H x W)" + + pred_mask = pred_mask.byte() + target = target.byte() + + # shift by 1 so that 255 is 0 + pred_mask += 1 + target += 1 + + pred_mask = pred_mask * (target > 0) + inter = pred_mask * (pred_mask == target) + area_inter = torch.histc(inter.float(), bins=num_classes, min=1, max=num_classes) + area_pred = torch.histc(pred_mask.float(), bins=num_classes, min=1, max=num_classes) + area_mask = torch.histc(target.float(), bins=num_classes, min=1, max=num_classes) + area_union = area_pred + area_mask - area_inter + epsilon + return area_inter, area_union diff --git a/Adaptive Frequency Filters/metrics/metric_monitor.py b/Adaptive Frequency Filters/metrics/metric_monitor.py new file mode 100644 index 0000000..de413d5 --- /dev/null +++ b/Adaptive Frequency Filters/metrics/metric_monitor.py @@ -0,0 +1,315 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import Tensor +from typing import Optional, Tuple, Any, Dict, Union +from numbers import Number + +from utils.tensor_utils import tensor_to_python_float +from utils import logger + +from .topk_accuracy import top_k_accuracy +from .intersection_over_union import compute_miou_batch +from .psnr import compute_psnr + + +def metric_monitor( + opts, + pred_label: Any, + target_label: Any, + loss: Tensor or float, + metric_names: list, + use_distributed: Optional[bool] = False, + grad_norm: Optional = None, + is_evaluation: Optional[bool] = False, + *args, + **kwargs +): + """ + This function aggregate different metrics and convert them into floats, so that + they can be easily consumed by stats.py file + """ + metric_vals = dict() + if "loss" in metric_names: + metric_vals["loss"] = gather_loss(loss, is_distributed=use_distributed) + + if "grad_norm" in metric_names: + metric_vals["grad_norm"] = gather_grad_norm( + grad_norm, is_distributed=use_distributed + ) + + if "top1" in metric_names: + top_1, top_5 = gather_top_k_metrics( + prediction=pred_label, target=target_label, is_distributed=use_distributed + ) + metric_vals["top1"] = top_1 + if "top5" in metric_names: + metric_vals["top5"] = top_5 + + if "iou" in metric_names: + inter, union = gather_iou_metrics( + prediction=pred_label, target=target_label, is_distributed=use_distributed + ) + metric_vals["iou"] = {"inter": inter, "union": union} + + if "psnr" in metric_names: + psnr = compute_psnr(prediction=pred_label, target=target_label) + metric_vals["psnr"] = tensor_to_python_float( + psnr, is_distributed=use_distributed + ) + + return metric_vals + + +def gather_loss( + loss: Union[Tensor, Dict], is_distributed: bool +) -> Union[Number, Dict[str, Number]]: + """ + This function gather losses from different processes and converts to float. + """ + if isinstance(loss, (int, float)): + return loss * 1.0 + elif isinstance(loss, Tensor): + return tensor_to_python_float(loss, is_distributed=is_distributed) + elif isinstance(loss, Dict): + loss_dict = {} + + if "total_loss" not in list(loss.keys()): + logger.error( + "total_loss key is required for loss functions that return outputs as dictionary." + ) + + for k, v in loss.items(): + if v is None: + continue + v_float = tensor_to_python_float(v, is_distributed=is_distributed) + loss_dict[k] = v_float + return loss_dict + else: + logger.error("Metric monitor supports Tensor or Dict of Tensors") + + +def gather_grad_norm( + grad_norm: Union[Tensor, Dict], is_distributed: bool +) -> Union[Number, Dict[str, Number]]: + """ + This function gather grad_norm from different processes and converts to float. + """ + if grad_norm is None: + return 1e-7 + + if isinstance(grad_norm, (int, float)): + return grad_norm * 1.0 + if isinstance(grad_norm, Tensor): + return tensor_to_python_float(grad_norm, is_distributed=is_distributed) + elif isinstance(grad_norm, Dict): + grad_norm_dict = {} + for k, v in grad_norm.items(): + if v is None: + continue + v_float = tensor_to_python_float(v, is_distributed=is_distributed) + grad_norm_dict[k] = v_float + return grad_norm_dict + else: + logger.error("Metric monitor supports Tensor or Dict of Tensors") + + +def gather_top_k_metrics( + prediction: Union[Tensor, Dict], target: Union[Tensor, Dict], is_distributed: bool +) -> Union[Tuple[Number, Number], Tuple[Dict[str, Number], Dict[str, Number]]]: + """ + This function gather top-1 and top-5 metrics from different processes and converts to float. + """ + # We have four combinations between prediction and target types: + # 1. (Tensor, Tensor) + # 2. (Dict, Tensor) + # 3. (Dict, Dict) + # 4. (Tensor, Dict) --> This combination is rare + + if isinstance(prediction, Tensor) and isinstance(target, Tensor): + top_1_acc, top_5_acc = top_k_accuracy(prediction, target, top_k=(1, 5)) + top_1_acc = tensor_to_python_float(top_1_acc, is_distributed=is_distributed) + top_5_acc = tensor_to_python_float(top_5_acc, is_distributed=is_distributed) + return top_1_acc, top_5_acc + elif isinstance(prediction, Dict) and isinstance(target, Tensor): + top1_dict = {} + top5_dict = {} + for pred_k, pred_v in prediction.items(): + if ( + isinstance(pred_v, Tensor) and pred_v.dim() == 2 and target.dim() == 1 + ): # Output tensor should be of size [batch_size, num_classes] and target should be of shape [batch_size] + top_1_acc, top_5_acc = top_k_accuracy(pred_v, target, top_k=(1, 5)) + top_1_acc = tensor_to_python_float( + top_1_acc, is_distributed=is_distributed + ) + top_5_acc = tensor_to_python_float( + top_5_acc, is_distributed=is_distributed + ) + top1_dict[pred_k] = top_1_acc + top5_dict[pred_k] = top_5_acc + return top1_dict, top5_dict + elif isinstance(prediction, Dict) and isinstance(target, Dict): + # prediction and target dictionaries should have intersecting keys + prediction_keys = prediction.keys() + target_keys = target.keys() + + intersection_keys = list(set(prediction_keys).intersection(target_keys)) + if len(intersection_keys) == 0: + logger.error( + "The keys in prediction and target are different. " + " Got: Prediction keys={} and Target keys={}".format( + prediction_keys, target_keys + ) + ) + + top1_dict = {} + top5_dict = {} + for pred_k in intersection_keys: + pred_v = prediction[pred_k] + target_v = target[pred_k] + if ( + isinstance(pred_v, Tensor) + and isinstance(target_v, Tensor) + and pred_v.dim() == 2 + and target_v.dim() == 1 + ): # Output tensor should be of size [batch_size, num_classes] and target should be of shape [batch_size] + top_1_acc, top_5_acc = top_k_accuracy(pred_v, target_v, top_k=(1, 5)) + top_1_acc = tensor_to_python_float( + top_1_acc, is_distributed=is_distributed + ) + top_5_acc = tensor_to_python_float( + top_5_acc, is_distributed=is_distributed + ) + top1_dict[pred_k] = top_1_acc + top5_dict[pred_k] = top_5_acc + return top1_dict, top5_dict + elif isinstance(prediction, Tensor) and isinstance(target, Dict): + # rare but possible + top1_dict = {} + top5_dict = {} + for target_k, target_v in target.items(): + if ( + isinstance(target_v, Tensor) + and prediction.dim() == 2 + and target_v.dim() == 1 + ): # Output tensor should be of size [batch_size, num_classes] and target should be of shape [batch_size] + top_1_acc, top_5_acc = top_k_accuracy( + prediction, target_v, top_k=(1, 5) + ) + top_1_acc = tensor_to_python_float( + top_1_acc, is_distributed=is_distributed + ) + top_5_acc = tensor_to_python_float( + top_5_acc, is_distributed=is_distributed + ) + top1_dict[target_k] = top_1_acc + top5_dict[target_k] = top_5_acc + return top1_dict, top5_dict + else: + logger.error("Metric monitor supports Tensor or Dict of Tensors") + + +def gather_iou_metrics( + prediction: Union[Tensor, Dict], target: Tensor, is_distributed: bool +) -> Union[Tuple[Number, Number], Tuple[Dict[str, Number], Dict[str, Number]]]: + """ + This function gathers intersection and union metrics from different processes and converts to float. + """ + if isinstance(prediction, Tensor) and isinstance(target, Tensor): + inter, union = compute_miou_batch(prediction=prediction, target=target) + inter = tensor_to_python_float(inter, is_distributed=is_distributed) + union = tensor_to_python_float(union, is_distributed=is_distributed) + return inter, union + # elif isinstance(prediction, Dict): + # logger.error("IOU metrics are not supported for a dictionary of predictions") + # We will revisit it later, as per the use case. + + # inter_dict = {} + # union_dict = {} + # for k, v in prediction.items(): + # inter, union = compute_miou_batch(prediction=v, target=target) + # inter = tensor_to_python_float(inter, is_distributed=is_distributed) + # union = tensor_to_python_float(union, is_distributed=is_distributed) + # inter_dict[k] = inter + # union_dict[k] = union + # return inter_dict, union_dict + else: + logger.error("Metric monitor supports Tensor only for IoU") + + +def gather_psnr_metrics( + prediction: Union[Tensor, Dict], target: Union[Tensor, Dict], is_distributed: bool +) -> Union[Number, Dict[str, Number]]: + """ + This function gathers psnr scores from different processes and converts to float. + """ + # We have four combinations between prediction and target types: + # 1. (Tensor, Tensor) + # 2. (Dict, Tensor) + # 3. (Dict, Dict) + # 4. (Tensor, Dict) --> This combination is rare + + if isinstance(prediction, Tensor) and isinstance(target, Tensor): + if prediction.numel() != target.numel(): + logger.error( + "Prediction and target have different number of elements." + "Got: Prediction={} and target={}".format( + prediction.shape, target.shape + ) + ) + psnr = compute_psnr(prediction=prediction, target=target) + psnr = tensor_to_python_float(psnr, is_distributed=is_distributed) + return psnr + elif isinstance(prediction, Dict) and isinstance(target, Tensor): + psnr_dict = {} + for pred_k, pred_v in prediction.items(): + # only compute PSNR where prediction size and target sizes are the same + if isinstance(pred_v, Tensor) and (pred_v.numel() == target.numel()): + psnr = compute_psnr(prediction=pred_v, target=target) + psnr = tensor_to_python_float(psnr, is_distributed=is_distributed) + psnr_dict[pred_k] = psnr + return psnr_dict + elif isinstance(prediction, Dict) and isinstance(target, Dict): + # prediction and target dictionaries should have intersecting keys + prediction_keys = prediction.keys() + target_keys = target.keys() + + intersection_keys = list(set(prediction_keys).intersection(target_keys)) + if len(intersection_keys) == 0: + logger.error( + "The keys in prediction and target are different. " + " Got: Prediction keys={} and Target keys={}".format( + prediction_keys, target_keys + ) + ) + + psnr_dict = {} + for pred_k in intersection_keys: + pred_v = prediction[pred_k] + target_v = target[pred_k] + # only compute PSNR where prediction size and target sizes are the same + if ( + isinstance(pred_v, Tensor) + and isinstance(target_v, Tensor) + and (pred_v.numel() == target_v.numel()) + ): + psnr = compute_psnr(prediction=pred_v, target=target_v) + psnr = tensor_to_python_float(psnr, is_distributed=is_distributed) + psnr_dict[pred_k] = psnr + return psnr_dict + elif isinstance(prediction, Tensor) and isinstance(target, Dict): + psnr_dict = {} + for target_k, target_v in target.items(): + # only compute PSNR where prediction size and target sizes are the same + if isinstance(target_v, Tensor) and ( + prediction.numel() == target_v.numel() + ): + psnr = compute_psnr(prediction=prediction, target=target_v) + psnr = tensor_to_python_float(psnr, is_distributed=is_distributed) + psnr_dict[target_k] = psnr + return psnr_dict + else: + logger.error("Metric monitor supports Tensor or Dict of Tensors") diff --git a/Adaptive Frequency Filters/metrics/psnr.py b/Adaptive Frequency Filters/metrics/psnr.py new file mode 100644 index 0000000..9b67c30 --- /dev/null +++ b/Adaptive Frequency Filters/metrics/psnr.py @@ -0,0 +1,29 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import Tensor +from typing import Optional + +from . import register_stats_fn + + +@register_stats_fn(name="psnr") +def compute_psnr( + prediction: Tensor, target: Tensor, no_uint8_conversion: Optional[bool] = False +) -> Tensor: + + if not no_uint8_conversion: + prediction = prediction.mul(255.0).to(torch.uint8) + target = target.mul(255.0).to(torch.uint8) + MAX_I = 255**2 + else: + MAX_I = 1 + + error = torch.pow(prediction - target, 2).float() + mse = torch.mean(error) + 1e-10 + psnr = 10.0 * torch.log10(MAX_I / mse) + return psnr diff --git a/Adaptive Frequency Filters/metrics/stats.py b/Adaptive Frequency Filters/metrics/stats.py new file mode 100644 index 0000000..da1c7bb --- /dev/null +++ b/Adaptive Frequency Filters/metrics/stats.py @@ -0,0 +1,234 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import sys +import time +import numpy as np +import torch +from utils import logger +from typing import Optional, Dict, Union, Any, List +from numbers import Number + +from . import SUPPORTED_STATS + + +class Statistics(object): + def __init__( + self, + metric_names: Optional[list] = ["loss"], + is_master_node: Optional[bool] = False, + ) -> None: + if len(metric_names) == 0: + logger.error("Metric names list cannot be empty") + + # key is the metric name and value is the value + metric_dict: Dict[str, Union[Any]] = {} + metric_counters = {} + for m_name in metric_names: + # Don't use coco_map key here as it is handled separately + if m_name == "coco_map": + continue + + if m_name in SUPPORTED_STATS: + metric_dict[m_name] = None + metric_counters[m_name] = 0 + else: + if is_master_node: + logger.log( + "{} statistics not supported. Supported: {}".format( + m_name, SUPPORTED_STATS + ) + ) + + self.metric_dict = metric_dict + self.supported_metrics = list(metric_dict.keys()) + self.metric_counters = metric_counters + self.round_places = 4 + self.is_master_node = is_master_node + + self.batch_time = 0 + self.batch_counter = 0 + + def update( + self, metric_vals: dict, batch_time: float, n: Optional[int] = 1 + ) -> None: + for k, v in metric_vals.items(): + if k in self.supported_metrics: + if self.metric_dict[k] is None: + if k == "iou": + if isinstance(v["inter"], np.ndarray): + self.metric_dict[k] = { + "inter": v["inter"] * n, + "union": v["union"] * n, + } + else: + logger.error( + "IOU computation is only supported using np.ndarray." + ) + elif isinstance(v, Dict): + self.metric_dict[k] = dict() + for k1, v1 in v.items(): + self.metric_dict[k][k1] = v1 * n + elif isinstance(v, Number): + self.metric_dict[k] = v * n + else: + logger.error( + "Dict[str, float] or float are supported in {}".format( + self.__class__.__name__ + ) + ) + else: + if k == "iou": + if isinstance(v["inter"], np.ndarray): + self.metric_dict[k]["inter"] += v["inter"] * n + self.metric_dict[k]["union"] += v["union"] * n + else: + logger.error( + "IOU computation is only supported using np.ndarray." + ) + elif isinstance(v, Dict): + for k1, v1 in v.items(): + self.metric_dict[k][k1] += v1 * n + elif isinstance(v, Number): + self.metric_dict[k] += v * n + else: + logger.error( + "Dict[str, float] or Number are supported in {}".format( + self.__class__.__name__ + ) + ) + + self.metric_counters[k] += n + self.batch_time += batch_time + self.batch_counter += 1 + + def avg_statistics_all(self, sep=": ") -> List[str]: + """ + This function computes average statistics of all metrics and returns them as a list of strings. + + Examples: + loss: 12.9152 + loss: {'total_loss': 12.9152, 'reg_loss': 2.8199, 'cls_loss': 10.0953} + """ + + metric_stats = [] + for k, v in self.metric_dict.items(): + counter = self.metric_counters[k] + + if k == "iou": + if isinstance(v["inter"], np.ndarray): + inter = (v["inter"] * 1.0) / counter + union = (v["union"] * 1.0) / counter + iou = inter / union + if isinstance(iou, torch.Tensor): + iou = iou.cpu().numpy() + # Converting iou from [0, 1] to [0, 100] + # other metrics are by default in [0, 100 range] + v_avg = np.mean(iou) * 100.0 + v_avg = round(v_avg, self.round_places) + else: + logger.error("IOU computation is only supported using np.ndarray.") + elif isinstance(v, Dict): + v_avg = {} + for k1, v1 in v.items(): + v_avg[k1] = round((v1 * 1.0) / counter, self.round_places) + else: + v_avg = round((v * 1.0) / counter, self.round_places) + + metric_stats.append("{:<}{}{}".format(k, sep, v_avg)) + return metric_stats + + def avg_statistics( + self, metric_name: str, sub_metric_name: Optional[str] = None, *args, **kwargs + ) -> float: + """ + This function computes the average statistics of a given metric. + + .. note:: + The statistics are stored in form of a dictionary and each key-value pair can be of string and number + OR string and dictionary of string and number. + + Examples: + {'loss': 10.0, 'top-1': 50.0} + {'loss': {'total_loss': 10.0, 'cls_loss': 2.0, 'reg_loss': 8.0}, 'mAP': 5.0} + + """ + avg_val = None + if metric_name in self.supported_metrics: + counter = self.metric_counters[metric_name] + v = self.metric_dict[metric_name] + + if metric_name == "iou": + if isinstance(v["inter"], np.ndarray): + inter = (v["inter"] * 1.0) / counter + union = (v["union"] * 1.0) / counter + iou = inter / union + if isinstance(iou, torch.Tensor): + iou = iou.cpu().numpy() + # Converting iou from [0, 1] to [0, 100] + # other metrics are by default in [0, 100 range] + avg_val = np.mean(iou) * 100.0 + avg_val = round(avg_val, self.round_places) + else: + logger.error("IOU computation is only supported using np.ndarray.") + + elif isinstance(v, Dict) and sub_metric_name is not None: + sub_metric_keys = list(v.keys()) + if sub_metric_name in sub_metric_keys: + avg_val = round( + (v[sub_metric_name] * 1.0) / counter, self.round_places + ) + else: + logger.error( + "{} not present in the dictionary. Available keys are: {}".format( + sub_metric_name, sub_metric_keys + ) + ) + elif isinstance(v, Number): + avg_val = round((v * 1.0) / counter, self.round_places) + + return avg_val + + def iter_summary( + self, + epoch: int, + n_processed_samples: int, + total_samples: int, + elapsed_time: float, + learning_rate: float or list, + ) -> None: + if self.is_master_node: + metric_stats = self.avg_statistics_all() + el_time_str = "Elapsed time: {:5.2f}".format(time.time() - elapsed_time) + if isinstance(learning_rate, float): + lr_str = "LR: {:1.6f}".format(learning_rate) + else: + learning_rate = [round(lr, 6) for lr in learning_rate] + lr_str = "LR: {}".format(learning_rate) + epoch_str = "Epoch: {:3d} [{:8d}/{:8d}]".format( + epoch, n_processed_samples, total_samples + ) + batch_str = "Avg. batch load time: {:1.3f}".format( + self.batch_time / self.batch_counter + ) + + stats_summary = [epoch_str] + stats_summary.extend(metric_stats) + stats_summary.append(lr_str) + stats_summary.append(batch_str) + stats_summary.append(el_time_str) + + summary_str = ", ".join(stats_summary) + logger.log(summary_str) + sys.stdout.flush() + + def epoch_summary(self, epoch: int, stage: Optional[str] = "Training") -> None: + if self.is_master_node: + metric_stats = self.avg_statistics_all(sep="=") + metric_stats_str = " || ".join(metric_stats) + logger.log("*** {} summary for epoch {}".format(stage.title(), epoch)) + print("\t {}".format(metric_stats_str)) + sys.stdout.flush() diff --git a/Adaptive Frequency Filters/metrics/topk_accuracy.py b/Adaptive Frequency Filters/metrics/topk_accuracy.py new file mode 100644 index 0000000..2a5f668 --- /dev/null +++ b/Adaptive Frequency Filters/metrics/topk_accuracy.py @@ -0,0 +1,30 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from torch import Tensor +from typing import Optional + +from . import register_stats_fn + + +@register_stats_fn(name="top1") +@register_stats_fn(name="top5") +def top_k_accuracy( + output: Tensor, target: Tensor, top_k: Optional[tuple] = (1,) +) -> list: + maximum_k = max(top_k) + batch_size = target.shape[0] + + _, pred = output.topk(maximum_k, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + results = [] + for k in top_k: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + acc_k = correct_k.mul_(100.0 / batch_size) + results.append(acc_k) + return results diff --git a/Adaptive Frequency Filters/optim/__init__.py b/Adaptive Frequency Filters/optim/__init__.py new file mode 100644 index 0000000..f3b551e --- /dev/null +++ b/Adaptive Frequency Filters/optim/__init__.py @@ -0,0 +1,169 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import importlib +from typing import List, Dict +import torch.nn +import argparse + +from utils import logger + +from .base_optim import BaseOptim + +OPTIM_REGISTRY = {} + + +def register_optimizer(name: str): + def register_optimizer_class(cls): + if name in OPTIM_REGISTRY: + raise ValueError("Cannot register duplicate optimizer ({})".format(name)) + + if not issubclass(cls, BaseOptim): + raise ValueError( + "Optimizer ({}: {}) must extend BaseOptim".format(name, cls.__name__) + ) + + OPTIM_REGISTRY[name] = cls + return cls + + return register_optimizer_class + + +def check_trainable_parameters(model: torch.nn.Module, model_params: List) -> None: + """Helper function to check if any model parameters w/ gradients are not part of model_params""" + + # get model parameter names + model_trainable_params = [] + for p_name, param in model.named_parameters(): + if param.requires_grad: + model_trainable_params.append(p_name) + + initialized_params = [] + for param_info in model_params: + if not isinstance(param_info, Dict): + logger.error( + "Expected format is a Dict with three keys: params, weight_decay, param_names" + ) + + if not {"params", "weight_decay", "param_names"}.issubset(param_info.keys()): + logger.error( + "Parameter dict should have three keys: params, weight_decay, param_names" + ) + + param_names = param_info.pop("param_names") + if isinstance(param_names, List): + initialized_params.extend(param_names) + elif isinstance(param_names, str): + initialized_params.append(param_names) + else: + raise NotImplementedError + + uninitialized_params = set(model_trainable_params) ^ set(initialized_params) + if len(uninitialized_params) > 0: + logger.error( + "Following parameters are defined in the model, but won't be part of optimizer. " + "Please check get_trainable_parameters function. " + "Use --optim.bypass-parameters-check flag to bypass this check. " + "Parameter list = {}".format(uninitialized_params) + ) + + +def remove_param_name_key(model_params: List) -> None: + """Helper function to remove param_names key from model_params""" + for param_info in model_params: + if not isinstance(param_info, Dict): + logger.error( + "Expected format is a Dict with three keys: params, weight_decay, param_names" + ) + + if not {"params", "weight_decay", "param_names"}.issubset(param_info.keys()): + logger.error( + "Parameter dict should have three keys: params, weight_decay, param_names" + ) + + param_info.pop("param_names") + + +def build_optimizer(model: torch.nn.Module, opts, *args, **kwargs) -> BaseOptim: + optim_name = getattr(opts, "optim.name", "sgd").lower() + optimizer = None + weight_decay = getattr(opts, "optim.weight_decay", 0.0) + no_decay_bn_filter_bias = getattr(opts, "optim.no_decay_bn_filter_bias", False) + + unwrapped_model = model.module if hasattr(model, "module") else model + + model_params, lr_mult = unwrapped_model.get_trainable_parameters( + weight_decay=weight_decay, + no_decay_bn_filter_bias=no_decay_bn_filter_bias, + *args, + **kwargs + ) + + # check to ensure that all trainable model parameters are passed to the model + if not getattr(opts, "optim.bypass_parameters_check", False): + check_trainable_parameters(model=unwrapped_model, model_params=model_params) + else: + remove_param_name_key(model_params=model_params) + + setattr(opts, "optim.lr_multipliers", lr_mult) + if optim_name in OPTIM_REGISTRY: + optimizer = OPTIM_REGISTRY[optim_name](opts, model_params) + else: + supp_list = list(OPTIM_REGISTRY.keys()) + supp_str = ( + "Optimizer ({}) not yet supported. \n Supported optimizers are:".format( + optim_name + ) + ) + for i, m_name in enumerate(supp_list): + supp_str += "\n\t {}: {}".format(i, logger.color_text(m_name)) + logger.error(supp_str) + + return optimizer + + +def general_optim_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group("optimizer", "Optimizer related arguments") + group.add_argument("--optim.name", default="sgd", help="Which optimizer") + group.add_argument("--optim.eps", type=float, default=1e-8, help="Optimizer eps") + group.add_argument( + "--optim.weight-decay", default=4e-5, type=float, help="Weight decay" + ) + group.add_argument( + "--optim.no-decay-bn-filter-bias", + action="store_true", + help="No weight decay in normalization layers and bias", + ) + group.add_argument( + "--optim.bypass-parameters-check", + action="store_true", + help="Bypass parameter check when creating optimizer", + ) + return parser + + +def arguments_optimizer(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser = general_optim_args(parser=parser) + + # add optim specific arguments + for k, v in OPTIM_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + + return parser + + +# automatically import the optimizers +optim_dir = os.path.dirname(__file__) +for file in os.listdir(optim_dir): + path = os.path.join(optim_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + optim_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("optim." + optim_name) diff --git a/Adaptive Frequency Filters/optim/adam.py b/Adaptive Frequency Filters/optim/adam.py new file mode 100644 index 0000000..64c32f3 --- /dev/null +++ b/Adaptive Frequency Filters/optim/adam.py @@ -0,0 +1,69 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse +from torch.optim import Adam + +from . import register_optimizer +from .base_optim import BaseOptim + + +@register_optimizer("adam") +class AdamOptimizer(BaseOptim, Adam): + """ + `Adam `_ optimizer + """ + + def __init__(self, opts, model_params) -> None: + BaseOptim.__init__(self, opts=opts) + beta1 = getattr(opts, "optim.adam.beta1", 0.9) + beta2 = getattr(opts, "optim.adam.beta2", 0.98) + ams_grad = getattr(opts, "optim.adam.amsgrad", False) + eps = getattr(opts, "optim.adam.eps", None) + Adam.__init__( + self, + params=model_params, + lr=self.lr, + betas=(beta1, beta2), + eps=self.eps if eps is None else eps, + weight_decay=self.weight_decay, + amsgrad=ams_grad, + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group("ADAM arguments", "ADAM arguments") + group.add_argument( + "--optim.adam.beta1", type=float, default=0.9, help="Adam Beta1" + ) + group.add_argument( + "--optim.adam.beta2", type=float, default=0.98, help="Adam Beta2" + ) + group.add_argument( + "--optim.adam.amsgrad", action="store_true", help="Use AMSGrad in ADAM" + ) + group.add_argument( + "--optim.adam.eps", type=float, default=None, help="Epsilon in Adam" + ) + return parser + + def __repr__(self) -> str: + group_dict = dict() + for i, group in enumerate(self.param_groups): + for key in sorted(group.keys()): + if key == "params": + continue + if key not in group_dict: + group_dict[key] = [group[key]] + else: + group_dict[key].append(group[key]) + + format_string = self.__class__.__name__ + " (" + format_string += "\n" + for k, v in group_dict.items(): + format_string += "\t {0}: {1}\n".format(k, v) + format_string += ")" + return format_string diff --git a/Adaptive Frequency Filters/optim/adamw.py b/Adaptive Frequency Filters/optim/adamw.py new file mode 100644 index 0000000..4c30ee0 --- /dev/null +++ b/Adaptive Frequency Filters/optim/adamw.py @@ -0,0 +1,69 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse +from torch.optim import AdamW + +from . import register_optimizer +from .base_optim import BaseOptim + + +@register_optimizer("adamw") +class AdamWOptimizer(BaseOptim, AdamW): + """ + `AdamW `_ optimizer + """ + + def __init__(self, opts, model_params) -> None: + BaseOptim.__init__(self, opts=opts) + beta1 = getattr(opts, "optim.adamw.beta1", 0.9) + beta2 = getattr(opts, "optim.adamw.beta2", 0.98) + ams_grad = getattr(opts, "optim.adamw.amsgrad", False) + eps = getattr(opts, "optim.adamw.eps", None) + AdamW.__init__( + self, + params=model_params, + lr=self.lr, + betas=(beta1, beta2), + eps=self.eps if eps is None else eps, + weight_decay=self.weight_decay, + amsgrad=ams_grad, + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group("AdamW arguments", "AdamW arguments") + group.add_argument( + "--optim.adamw.beta1", type=float, default=0.9, help="Adam Beta1" + ) + group.add_argument( + "--optim.adamw.beta2", type=float, default=0.98, help="Adam Beta2" + ) + group.add_argument( + "--optim.adamw.amsgrad", action="store_true", help="Use AMSGrad in ADAM" + ) + group.add_argument( + "--optim.adamw.eps", type=float, default=None, help="Epsilon in Adam" + ) + return parser + + def __repr__(self) -> str: + group_dict = dict() + for i, group in enumerate(self.param_groups): + for key in sorted(group.keys()): + if key == "params": + continue + if key not in group_dict: + group_dict[key] = [group[key]] + else: + group_dict[key].append(group[key]) + + format_string = self.__class__.__name__ + " (" + format_string += "\n" + for k, v in group_dict.items(): + format_string += "\t {0}: {1}\n".format(k, v) + format_string += ")" + return format_string diff --git a/Adaptive Frequency Filters/optim/base_optim.py b/Adaptive Frequency Filters/optim/base_optim.py new file mode 100644 index 0000000..241f0be --- /dev/null +++ b/Adaptive Frequency Filters/optim/base_optim.py @@ -0,0 +1,20 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse + + +class BaseOptim(object): + """Base class for optimizer""" + + def __init__(self, opts) -> None: + self.eps = 1e-8 + self.lr = getattr(opts, "scheduler.lr", 0.1) + self.weight_decay = getattr(opts, "optim.weight_decay", 4e-5) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + return parser diff --git a/Adaptive Frequency Filters/optim/scheduler/__init__.py b/Adaptive Frequency Filters/optim/scheduler/__init__.py new file mode 100644 index 0000000..472a1fe --- /dev/null +++ b/Adaptive Frequency Filters/optim/scheduler/__init__.py @@ -0,0 +1,118 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import importlib +from utils import logger +import argparse + +from .base_scheduler import BaseLRScheduler + +SCHEDULER_REGISTRY = {} + + +def register_scheduler(name: str): + def register_scheduler_class(cls): + if name in SCHEDULER_REGISTRY: + raise ValueError("Cannot register duplicate scheduler ({})".format(name)) + + if not issubclass(cls, BaseLRScheduler): + raise ValueError( + "LR Scheduler ({}: {}) must extend BaseLRScheduler".format( + name, cls.__name__ + ) + ) + + SCHEDULER_REGISTRY[name] = cls + return cls + + return register_scheduler_class + + +def build_scheduler(opts) -> BaseLRScheduler: + scheduler_name = getattr(opts, "scheduler.name", "cosine").lower() + lr_scheduler = None + if scheduler_name in SCHEDULER_REGISTRY: + lr_scheduler = SCHEDULER_REGISTRY[scheduler_name](opts) + else: + supp_list = list(SCHEDULER_REGISTRY.keys()) + supp_str = ( + "LR Scheduler ({}) not yet supported. \n Supported schedulers are:".format( + scheduler_name + ) + ) + for i, m_name in enumerate(supp_list): + supp_str += "\n\t {}: {}".format(i, logger.color_text(m_name)) + logger.error(supp_str) + + return lr_scheduler + + +def general_lr_sch_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="LR scheduler arguments", description="LR scheduler arguments" + ) + + group.add_argument( + "--scheduler.name", type=str, default="cosine", help="LR scheduler name" + ) + group.add_argument("--scheduler.lr", type=float, default=0.1, help="Learning rate") + group.add_argument( + "--scheduler.max-epochs", + type=int, + default=None, + help="Max. epochs for training", + ) + group.add_argument( + "--scheduler.max-iterations", + type=int, + default=None, + help="Max. iterations for training", + ) + group.add_argument( + "--scheduler.warmup-iterations", + type=int, + default=None, + help="Warm-up iterations", + ) + group.add_argument( + "--scheduler.warmup-init-lr", type=float, default=1e-7, help="Warm-up init lr" + ) + group.add_argument( + "--scheduler.is-iteration-based", + action="store_true", + help="Is iteration type or epoch type", + ) + + group.add_argument( + "--scheduler.adjust-period-for-epochs", + action="store_true", + help="Adjust the period for epoch-based scheduler.", + ) + + return parser + + +def arguments_scheduler(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser = general_lr_sch_args(parser=parser) + + # add scheduler specific arguments + for k, v in SCHEDULER_REGISTRY.items(): + parser = v.add_arguments(parser=parser) + return parser + + +# automatically import the LR schedulers +lr_sch_dir = os.path.dirname(__file__) +for file in os.listdir(lr_sch_dir): + path = os.path.join(lr_sch_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + lr_sch_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("optim.scheduler." + lr_sch_name) diff --git a/Adaptive Frequency Filters/optim/scheduler/base_scheduler.py b/Adaptive Frequency Filters/optim/scheduler/base_scheduler.py new file mode 100644 index 0000000..9383c15 --- /dev/null +++ b/Adaptive Frequency Filters/optim/scheduler/base_scheduler.py @@ -0,0 +1,58 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse + + +class BaseLRScheduler(object): + def __init__(self, opts) -> None: + warmup_iterations = getattr(opts, "scheduler.warmup_iterations", None) + super().__init__() + self.opts = opts + self.round_places = 8 + self.lr_multipliers = getattr(opts, "optim.lr_multipliers", None) + + self.warmup_iterations = ( + max(warmup_iterations, 0) if warmup_iterations is not None else 0 + ) + + warmup_init_lr = getattr(opts, "scheduler.warmup_init_lr", 1e-7) + self.warmup_init_lr = warmup_init_lr + + # Because of variable batch sizes, we can't determine exact number of epochs in warm-up phase. This + # may result in different LR schedules when we run epoch- and iteration-based schedulers. + # To reduce these differences, we use adjust_period_for_epochs arguments. + # For epoch-based scheduler, this parameter value should be enabled. + self.adjust_period = getattr(opts, "scheduler.adjust_period_for_epochs", False) + self.warmup_epochs = 0 + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + return parser + + def get_lr(self, epoch: int, curr_iter: int): + raise NotImplementedError + + def update_lr(self, optimizer, epoch: int, curr_iter: int): + lr = self.get_lr(epoch=epoch, curr_iter=curr_iter) + lr = max(0.0, lr) + if self.lr_multipliers is not None: + assert len(self.lr_multipliers) == len(optimizer.param_groups) + for g_id, param_group in enumerate(optimizer.param_groups): + param_group["lr"] = round( + lr * self.lr_multipliers[g_id], self.round_places + ) + else: + for param_group in optimizer.param_groups: + param_group["lr"] = round(lr, self.round_places) + return optimizer + + @staticmethod + def retrieve_lr(optimizer) -> list: + lr_list = [] + for param_group in optimizer.param_groups: + lr_list.append(param_group["lr"]) + return lr_list diff --git a/Adaptive Frequency Filters/optim/scheduler/cosine.py b/Adaptive Frequency Filters/optim/scheduler/cosine.py new file mode 100644 index 0000000..ac52d5e --- /dev/null +++ b/Adaptive Frequency Filters/optim/scheduler/cosine.py @@ -0,0 +1,93 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from . import register_scheduler +from .base_scheduler import BaseLRScheduler +import argparse +import math + + +@register_scheduler("cosine") +class CosineScheduler(BaseLRScheduler): + """ + Cosine learning rate scheduler: https://arxiv.org/abs/1608.03983 + """ + + def __init__(self, opts, **kwargs) -> None: + is_iter_based = getattr(opts, "scheduler.is_iteration_based", True) + super(CosineScheduler, self).__init__(opts=opts) + + max_iterations = getattr(opts, "scheduler.max_iterations", 150000) + + self.min_lr = getattr(opts, "scheduler.cosine.min_lr", 1e-5) + self.max_lr = getattr(opts, "scheduler.cosine.max_lr", 0.4) + + if self.warmup_iterations > 0: + self.warmup_step = ( + self.max_lr - self.warmup_init_lr + ) / self.warmup_iterations + + self.period = ( + max_iterations - self.warmup_iterations + 1 + if is_iter_based + else getattr(opts, "scheduler.max_epochs", 350) + ) + + self.is_iter_based = is_iter_based + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="Cosine LR arguments", description="Cosine LR arguments" + ) + + group.add_argument( + "--scheduler.cosine.min-lr", + type=float, + default=1e-5, + help="Minimum LR in Cosine LR scheduler", + ) + group.add_argument( + "--scheduler.cosine.max-lr", + type=float, + default=0.1, + help="Maximum LR in Cosine LR scheduler", + ) + return parser + + def get_lr(self, epoch: int, curr_iter: int) -> float: + if curr_iter < self.warmup_iterations: + curr_lr = self.warmup_init_lr + curr_iter * self.warmup_step + self.warmup_epochs = epoch + else: + if self.is_iter_based: + curr_iter = curr_iter - self.warmup_iterations + curr_lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * ( + 1 + math.cos(math.pi * curr_iter / self.period) + ) + else: + adjust_num = self.warmup_epochs + 1 if self.adjust_period else 0 + adjust_den = self.warmup_epochs if self.adjust_period else 0 + curr_lr = self.min_lr + 0.5 * (self.max_lr - self.min_lr) * ( + 1 + + math.cos( + math.pi * (epoch - adjust_num) / (self.period - adjust_den) + ) + ) + return max(0.0, curr_lr) + + def __repr__(self) -> str: + repr_str = "{}(".format(self.__class__.__name__) + repr_str += "\n \t min_lr={}\n \t max_lr={}\n \t period={}".format( + self.min_lr, self.max_lr, self.period + ) + if self.warmup_iterations > 0: + repr_str += "\n \t warmup_init_lr={}\n \t warmup_iters={}".format( + self.warmup_init_lr, self.warmup_iterations + ) + + repr_str += "\n )" + return repr_str diff --git a/Adaptive Frequency Filters/optim/scheduler/cyclic.py b/Adaptive Frequency Filters/optim/scheduler/cyclic.py new file mode 100644 index 0000000..aca03f0 --- /dev/null +++ b/Adaptive Frequency Filters/optim/scheduler/cyclic.py @@ -0,0 +1,180 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse +import math +from utils import logger +import numpy as np + +from . import register_scheduler +from .base_scheduler import BaseLRScheduler + +SUPPORTED_LAST_CYCLES = ["cosine", "linear"] + + +@register_scheduler("cyclic") +class CyclicLRScheduler(BaseLRScheduler): + """ + Cyclic LR: https://arxiv.org/abs/1811.11431 + """ + + def __init__(self, opts, **kwargs) -> None: + + cycle_steps = getattr(opts, "scheduler.cyclic.steps", [25]) + if cycle_steps is not None and isinstance(cycle_steps, int): + cycle_steps = [cycle_steps] + gamma = getattr(opts, "scheduler.cyclic.gamma", 0.5) + anneal_type = getattr(opts, "scheduler.cyclic.last_cycle_type", "linear") + min_lr = getattr(opts, "scheduler.cyclic.min_lr", 0.1) + end_lr = getattr(opts, "scheduler.cyclic.last_cycle_end_lr", 1e-3) + ep_per_cycle = getattr(opts, "scheduler.cyclic.epochs_per_cycle", 5) + warmup_iterations = getattr(opts, "scheduler.warmup_iterations", 0) + n_cycles = getattr(opts, "scheduler.cyclic.total_cycles", 10) - 1 + max_epochs = getattr(opts, "scheduler.max_epochs", 100) + + if anneal_type not in SUPPORTED_LAST_CYCLES: + logger.error( + "Supported anneal types for {} are: {}".format( + self.__class__.__name__, SUPPORTED_LAST_CYCLES + ) + ) + if min_lr < end_lr: + logger.error( + "Min LR should be greater than end LR. Got: {} and {}".format( + min_lr, end_lr + ) + ) + + super(CyclicLRScheduler, self).__init__(opts=opts) + self.min_lr = min_lr + self.cycle_length = ep_per_cycle + self.end_lr = end_lr + self.max_lr = self.min_lr * self.cycle_length + self.last_cycle_anneal_type = anneal_type + + if self.warmup_iterations > 0: + self.warmup_step = ( + self.min_lr - self.warmup_init_lr + ) / self.warmup_iterations + + self.n_cycles = n_cycles + + self.cyclic_epochs = self.cycle_length * self.n_cycles + self.max_epochs = max_epochs + self.last_cycle_epochs = self.max_epochs - self.cyclic_epochs + + assert self.max_epochs == self.cyclic_epochs + self.last_cycle_epochs + + self.steps = [self.max_epochs] if cycle_steps is None else cycle_steps + self.gamma = gamma if cycle_steps is not None else 1 + + self._lr_per_cycle() + self.epochs_lr_stepped = [] + + def _lr_per_cycle(self) -> None: + lrs = list( + np.linspace(self.max_lr, self.min_lr, self.cycle_length, dtype=np.float) + ) + lrs = [lrs[-1]] + lrs[:-1] + self.cycle_lrs = lrs + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="Cyclic LR arguments", description="Cyclic LR arguments" + ) + group.add_argument( + "--scheduler.cyclic.min-lr", + default=0.1, + type=float, + help="Min. lr for a cycle", + ) + group.add_argument( + "--scheduler.cyclic.last-cycle-end-lr", + default=1e-3, + type=float, + help="End LR for the last cycle", + ) + group.add_argument( + "--scheduler.cyclic.total-cycles", + default=11, + type=int, + help="Number of cycles. Default is 10", + ) + group.add_argument( + "--scheduler.cyclic.epochs-per-cycle", + default=5, + type=int, + help="Number of epochs per cycle. Default is 5", + ) + group.add_argument( + "--scheduler.cyclic.steps", + default=None, + type=int, + nargs="+", + help="steps at which LR should be decreased", + ) + group.add_argument( + "--scheduler.cyclic.gamma", + default=0.5, + type=float, + help="Factor by which LR should be decreased", + ) + group.add_argument( + "--scheduler.cyclic.last-cycle-type", + default="linear", + type=str, + choices=SUPPORTED_LAST_CYCLES, + help="Annealing in last cycle", + ) + return parser + + def get_lr(self, epoch: int, curr_iter: int) -> float: + if curr_iter < self.warmup_iterations: + curr_lr = self.warmup_init_lr + curr_iter * self.warmup_step + else: + if epoch <= self.cyclic_epochs: + if epoch in self.steps and epoch not in self.epochs_lr_stepped: + self.min_lr *= self.gamma ** (self.steps.index(epoch) + 1) + self.max_lr *= self.gamma ** (self.steps.index(epoch) + 1) + self._lr_per_cycle() + self.epochs_lr_stepped.append(epoch) + idx = epoch % self.cycle_length + curr_lr = self.cycle_lrs[idx] + else: + base_lr = self.min_lr + if self.last_cycle_anneal_type == "linear": + lr_step = (base_lr - self.end_lr) / self.last_cycle_epochs + curr_lr = base_lr - (epoch - self.cyclic_epochs + 1) * lr_step + elif self.last_cycle_anneal_type == "cosine": + curr_epoch = epoch - self.cyclic_epochs + period = self.max_epochs - self.cyclic_epochs + 1 + curr_lr = self.end_lr + 0.5 * (base_lr - self.end_lr) * ( + 1 + math.cos(math.pi * curr_epoch / period) + ) + else: + raise NotImplementedError + return max(0.0, curr_lr) + + def __repr__(self): + repr_str = ( + "{}(\n \t C={},\n \t C_length={},\n \t C_last={},\n \t Total_Epochs={}, " + "\n \t steps={},\n \t gamma={},\n \t last_cycle_anneal_method={} " + "\n \t min_lr={}, \n\t max_lr={}, \n\t end_lr={}\n)".format( + self.__class__.__name__, + self.n_cycles, + self.cycle_length, + self.last_cycle_epochs, + self.max_epochs, + self.steps, + self.gamma, + self.last_cycle_anneal_type, + self.min_lr, + self.min_lr * self.cycle_length, + self.end_lr, + ) + ) + return repr_str diff --git a/Adaptive Frequency Filters/optim/scheduler/fixed.py b/Adaptive Frequency Filters/optim/scheduler/fixed.py new file mode 100644 index 0000000..b53864b --- /dev/null +++ b/Adaptive Frequency Filters/optim/scheduler/fixed.py @@ -0,0 +1,69 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from . import register_scheduler +from .base_scheduler import BaseLRScheduler +import argparse +import math + + +@register_scheduler("fixed") +class FixedLRScheduler(BaseLRScheduler): + """ + Fixed learning rate scheduler with optional linear warm-up strategy + """ + + def __init__(self, opts, **kwargs) -> None: + is_iter_based = getattr(opts, "scheduler.is_iteration_based", True) + super(FixedLRScheduler, self).__init__(opts=opts) + + max_iterations = getattr(opts, "scheduler.max_iterations", 150000) + + self.fixed_lr = getattr(opts, "scheduler.fixed.lr", None) + assert self.fixed_lr is not None + + if self.warmup_iterations > 0: + self.warmup_step = ( + self.fixed_lr - self.warmup_init_lr + ) / self.warmup_iterations + + self.period = ( + max_iterations - self.warmup_iterations + 1 + if is_iter_based + else getattr(opts, "scheduler.max_epochs", 350) + ) + + self.is_iter_based = is_iter_based + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="Fixed LR arguments", description="Fixed LR arguments" + ) + + group.add_argument( + "--scheduler.fixed.lr", type=float, default=None, help="LR value" + ) + + return parser + + def get_lr(self, epoch: int, curr_iter: int) -> float: + if curr_iter < self.warmup_iterations: + curr_lr = self.warmup_init_lr + curr_iter * self.warmup_step + else: + curr_lr = self.fixed_lr + return max(0.0, curr_lr) + + def __repr__(self) -> str: + repr_str = "{}(".format(self.__class__.__name__) + repr_str += "\n\tlr={}".format(self.fixed_lr) + if self.warmup_iterations > 0: + repr_str += "\n\twarmup_init_lr={}\n\twarmup_iters={}".format( + self.warmup_init_lr, self.warmup_iterations + ) + + repr_str += "\n )" + return repr_str diff --git a/Adaptive Frequency Filters/optim/scheduler/multi_step.py b/Adaptive Frequency Filters/optim/scheduler/multi_step.py new file mode 100644 index 0000000..4db96c6 --- /dev/null +++ b/Adaptive Frequency Filters/optim/scheduler/multi_step.py @@ -0,0 +1,95 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from . import register_scheduler +from .base_scheduler import BaseLRScheduler +import argparse +import math + + +@register_scheduler("multi_step") +class MultiStepLRScheduler(BaseLRScheduler): + """ + Multi-step learning rate scheduler with optional linear warm-up strategy + """ + + def __init__(self, opts, **kwargs) -> None: + is_iter_based = getattr(opts, "scheduler.is_iteration_based", True) + super().__init__(opts=opts) + + max_iterations = getattr(opts, "scheduler.max_iterations", 150000) + + self.lr = getattr(opts, "scheduler.multi_step.lr", None) + assert self.lr is not None + + if self.warmup_iterations > 0: + self.warmup_step = (self.lr - self.warmup_init_lr) / self.warmup_iterations + + milestones = getattr(opts, "scheduler.multi_step.milestones", None) + if milestones is None: + milestones = [-1] + elif isinstance(milestones, int): + milestones = [milestones] + + self.milestones = sorted( + list(set(milestones)) + ) # remove duplicates and sort them + self.gamma = getattr(opts, "scheduler.multi_step.gamma", 1.0) + + self.period = ( + max_iterations - self.warmup_iterations + 1 + if is_iter_based + else getattr(opts, "scheduler.max_epochs", 350) + ) + + self.is_iter_based = is_iter_based + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="{} arguments".format(cls.__name__), + description="{} arguments".format(cls.__name__), + ) + + group.add_argument( + "--scheduler.multi-step.lr", type=float, default=0.1, help="LR value" + ) + group.add_argument( + "--scheduler.multi-step.gamma", + type=float, + default=None, + help="Decay LR value by this factor", + ) + group.add_argument( + "--scheduler.multi-step.milestones", + type=int, + nargs="+", + default=None, + help="Decay LR value at these epoch", + ) + return parser + + def get_lr(self, epoch: int, curr_iter: int) -> float: + if curr_iter < self.warmup_iterations: + return max(0.0, self.warmup_init_lr + curr_iter * self.warmup_step) + else: + if epoch in self.milestones: + self.lr = self.lr * self.gamma + self.milestones.remove(epoch) + return max(0.0, self.lr) + + def __repr__(self) -> str: + repr_str = "{}(".format(self.__class__.__name__) + repr_str += "\n\tlr={}\n\tmilestones={}\n\tgamma={}".format( + self.lr, self.milestones, self.gamma + ) + if self.warmup_iterations > 0: + repr_str += "\n\twarmup_init_lr={}\n\twarmup_iters={}".format( + self.warmup_init_lr, self.warmup_iterations + ) + + repr_str += "\n )" + return repr_str diff --git a/Adaptive Frequency Filters/optim/scheduler/polynomial.py b/Adaptive Frequency Filters/optim/scheduler/polynomial.py new file mode 100644 index 0000000..7dfba1a --- /dev/null +++ b/Adaptive Frequency Filters/optim/scheduler/polynomial.py @@ -0,0 +1,94 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse + +from . import register_scheduler +from .base_scheduler import BaseLRScheduler + + +@register_scheduler("polynomial") +class PolynomialScheduler(BaseLRScheduler): + """ + Polynomial LR scheduler + """ + + def __init__(self, opts, **kwargs) -> None: + is_iter_based = getattr(opts, "scheduler.is_iteration_based", False) + max_iterations = getattr(opts, "scheduler.max_iterations", 50000) + max_epochs = getattr(opts, "scheduler.max_epochs", 350) + + super(PolynomialScheduler, self).__init__(opts=opts) + + self.start_lr = getattr(opts, "scheduler.polynomial.start_lr", 0.1) + self.end_lr = getattr(opts, "scheduler.polynomial.end_lr", 0.0) + self.power = getattr(opts, "scheduler.polynomial.power", 0.9) + + if self.warmup_iterations > 0: + self.warmup_step = ( + self.start_lr - self.warmup_init_lr + ) / self.warmup_iterations + + self.is_iter_based = is_iter_based + self.max_iterations = max_iterations - self.warmup_iterations + 1 + self.max_epochs = max_epochs + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="Polynomial LR arguments", description="Polynomial LR arguments" + ) + + group.add_argument( + "--scheduler.polynomial.power", + type=float, + default=0.9, + help="Polynomial power", + ) + group.add_argument( + "--scheduler.polynomial.start-lr", + type=float, + default=0.1, + help="Start LR in Poly LR scheduler", + ) + group.add_argument( + "--scheduler.polynomial.end-lr", + type=float, + default=0.0, + help="End LR in Poly LR scheduler", + ) + + return parser + + def get_lr(self, epoch: int, curr_iter: int) -> float: + if curr_iter < self.warmup_iterations: + curr_lr = self.warmup_init_lr + curr_iter * self.warmup_step + self.warmup_epochs = epoch + else: + if self.is_iter_based: + factor = (curr_iter - self.warmup_iterations) / self.max_iterations + else: + adjust_num = self.warmup_epochs + 1 if self.adjust_period else 0 + adjust_den = self.warmup_epochs if self.adjust_period else 0 + factor = (epoch - adjust_num) / (self.max_epochs - adjust_den) + curr_lr = (self.start_lr - self.end_lr) * ( + (1.0 - factor) ** self.power + ) + self.end_lr + return max(0.0, curr_lr) + + def __repr__(self) -> str: + repr_str = "{}(".format(self.__class__.__name__) + repr_str += "\n\tpower={}\n\tstart_lr={}".format(self.power, self.start_lr) + if self.end_lr > 0: + repr_str += "\n\tend_lr={}".format(self.end_lr) + + if self.warmup_iterations > 0: + repr_str += "\n\twarmup_init_lr={}\n\twarmup_iters={}".format( + self.warmup_init_lr, self.warmup_iterations + ) + + repr_str += "\n )" + return repr_str diff --git a/Adaptive Frequency Filters/optim/sgd.py b/Adaptive Frequency Filters/optim/sgd.py new file mode 100644 index 0000000..81c6977 --- /dev/null +++ b/Adaptive Frequency Filters/optim/sgd.py @@ -0,0 +1,61 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse +from torch.optim import SGD + +from . import register_optimizer +from .base_optim import BaseOptim + + +@register_optimizer("sgd") +class SGDOptimizer(BaseOptim, SGD): + """ + `SGD `_ optimizer + """ + + def __init__(self, opts, model_params) -> None: + BaseOptim.__init__(self, opts=opts) + nesterov = getattr(opts, "optim.sgd.nesterov", False) + momentum = getattr(opts, "optim.sgd.momentum", 0.9) + + SGD.__init__( + self, + params=model_params, + lr=self.lr, + momentum=momentum, + weight_decay=self.weight_decay, + nesterov=nesterov, + ) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group("SGD arguments", "SGD arguments") + group.add_argument( + "--optim.sgd.momentum", default=0.9, type=float, help="Momemtum in SGD" + ) + group.add_argument( + "--optim.sgd.nesterov", action="store_true", help="Use nesterov in SGD" + ) + return parser + + def __repr__(self) -> str: + group_dict = dict() + for i, group in enumerate(self.param_groups): + for key in sorted(group.keys()): + if key == "params": + continue + if key not in group_dict: + group_dict[key] = [group[key]] + else: + group_dict[key].append(group[key]) + + format_string = self.__class__.__name__ + " (" + format_string += "\n" + for k, v in group_dict.items(): + format_string += "\t {0}: {1}\n".format(k, v) + format_string += ")" + return format_string diff --git a/Adaptive Frequency Filters/options/__init__.py b/Adaptive Frequency Filters/options/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/options/opts.py b/Adaptive Frequency Filters/options/opts.py new file mode 100644 index 0000000..a2ff603 --- /dev/null +++ b/Adaptive Frequency Filters/options/opts.py @@ -0,0 +1,522 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse +from typing import Optional + +from common import SUPPORTED_MODALITIES +from affnet import modeling_arguments +from data.collate_fns import arguments_collate_fn +from data.datasets import arguments_dataset +from data.sampler import arguments_sampler +from data.text_tokenizer import arguments_tokenizer +from data.transforms import arguments_augmentation +from data.video_reader import arguments_video_reader +from loss_fn import arguments_loss_fn +from metrics import arguments_stats +from optim import arguments_optimizer +from optim.scheduler import arguments_scheduler +from options.utils import load_config_file +from utils import logger + + +class ParseKwargs(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + # convert values into dict + override_dict = {} + for val in values: + if val.find("=") < 0: + logger.error( + "For override arguments, a key-value pair of the form key=value is expected. Got: {}".format( + val + ) + ) + val_list = val.split("=") + if len(val_list) != 2: + logger.error( + "For override arguments, a key-value pair of the form key=value is expected with only one value per key. Got: {}".format( + val + ) + ) + override_dict[val_list[0]] = val_list[1] + + # determine the type of each value from parser actions and set accordingly + options = parser._actions + for option in options: + option_dest = option.dest + if option_dest in override_dict: + val = override_dict[option_dest] + if type(option.default) == bool and option.nargs == 0: + # Boolean argument + # value could be false, False, true, True + override_dict[option_dest] = ( + True if val.lower().find("true") > -1 else False + ) + elif option.nargs is None: + # when nargs is not defined, it is usually a string, int, and float. + override_dict[option_dest] = option.type(val) + elif option.nargs in ["+", "*"]: + # for list, we expect value to be comma separated + val_list = val.split(",") + override_dict[option_dest] = [option.type(v) for v in val_list] + else: + logger.error( + "Following option is not yet supported for overriding. Please specify in config file. Got: {}".format( + option + ) + ) + setattr(namespace, "override_args", override_dict) + + +def arguments_common(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="Common arguments", description="Common arguments" + ) + + group.add_argument("--common.seed", type=int, default=0, help="Random seed") + group.add_argument( + "--common.config-file", type=str, default=None, help="Configuration file" + ) + group.add_argument( + "--common.results-loc", + type=str, + default="results", + help="Directory where results will be stored", + ) + group.add_argument( + "--common.run-label", + type=str, + default="run_1", + help="Label id for the current run", + ) + group.add_argument( + "--common.eval-stage-name", + type=str, + default="evaluation", + help="Name to be used while logging in evaluation stage.", + ) + + group.add_argument( + "--common.resume", type=str, default=None, help="Resume location" + ) + group.add_argument( + "--common.finetune_imagenet1k", + type=str, + default=None, + help="Checkpoint location to be used for finetuning", + ) + group.add_argument( + "--common.finetune_imagenet1k-ema", + type=str, + default=None, + help="EMA Checkpoint location to be used for finetuning", + ) + + group.add_argument( + "--common.mixed-precision", action="store_true", help="Mixed precision training" + ) + group.add_argument( + "--common.mixed-precision-dtype", + type=str, + default="float16", + help="Mixed precision training data type", + ) + group.add_argument( + "--common.accum-freq", + type=int, + default=1, + help="Accumulate gradients for this number of iterations", + ) + group.add_argument( + "--common.accum-after-epoch", + type=int, + default=0, + help="Start accumulation after this many epochs", + ) + group.add_argument( + "--common.log-freq", + type=int, + default=100, + help="Display after these many iterations", + ) + group.add_argument( + "--common.auto-resume", + action="store_true", + help="Resume training from the last checkpoint", + ) + group.add_argument( + "--common.grad-clip", type=float, default=None, help="Gradient clipping value" + ) + group.add_argument( + "--common.k-best-checkpoints", + type=int, + default=5, + help="Keep k-best checkpoints", + ) + group.add_argument( + "--common.save-all-checkpoints", + action="store_true", + default=False, + help="If True, will save checkpoints from all epochs", + ) + + group.add_argument( + "--common.inference-modality", + type=str, + default="image", + choices=SUPPORTED_MODALITIES, + help="Inference modality. Image or videos", + ) + + group.add_argument( + "--common.channels-last", + action="store_true", + default=False, + help="Use channel last format during training. " + "Note 1: that some models may not support it, so we recommend to use it with caution" + "Note 2: Channel last format does not work with 1-, 2-, and 3- tensors. " + "Therefore, we support it via custom collate functions", + ) + + group.add_argument( + "--common.tensorboard-logging", + action="store_true", + help="Enable tensorboard logging", + ) + group.add_argument( + "--common.bolt-logging", action="store_true", help="Enable bolt logging" + ) + + group.add_argument( + "--common.override-kwargs", + nargs="*", + action=ParseKwargs, + help="Override arguments. Example. To override the value of --sampler.vbs.crop-size-width, " + "we can pass override argument as " + "--common.override-kwargs sampler.vbs.crop_size_width=512 \n " + "Note that keys in override arguments do not contain -- or -", + ) + + group.add_argument( + "--common.enable-coreml-compatible-module", + action="store_true", + help="Use coreml compatible modules (if applicable) during inference", + ) + + group.add_argument( + "--common.debug-mode", + action="store_true", + help="You can use this flag for debugging purposes.", + ) + + # intermediate checkpoint related args + group.add_argument( + "--common.save-interval-freq", + type=int, + default=0, + help="Save checkpoints every N updates. Defaults to 0", + ) + + return parser + + +def arguments_ddp(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + title="DDP arguments", description="DDP arguments" + ) + group.add_argument("--ddp.disable", action="store_true", help="Don't use DDP") + group.add_argument( + "--ddp.rank", type=int, default=0, help="Node rank for distributed training" + ) + group.add_argument( + "--ddp.world-size", type=int, default=-1, help="World size for DDP" + ) + group.add_argument("--ddp.dist-url", type=str, default=None, help="DDP URL") + group.add_argument( + "--ddp.dist-port", + type=int, + default=30786, + help="DDP Port. Only used when --ddp.dist-url is not specified", + ) + group.add_argument("--ddp.device-id", type=int, default=None, help="Device ID") + group.add_argument( + "--ddp.no-spawn", action="store_true", help="Don't use DDP with spawn" + ) + group.add_argument( + "--ddp.backend", type=str, default="nccl", help="DDP backend. Default is nccl" + ) + group.add_argument( + "--ddp.find-unused-params", + action="store_true", + help="Find unused params in model. useful for debugging with DDP", + ) + + return parser + + +def parser_to_opts(parser: argparse.ArgumentParser): + # parse args + opts = parser.parse_args() + opts = load_config_file(opts) + return opts + + +def get_training_arguments(parse_args: Optional[bool] = True): + parser = argparse.ArgumentParser(description="Training arguments", add_help=True) + + # cvnet arguments, including models + parser = modeling_arguments(parser=parser) + + # sampler related arguments + parser = arguments_sampler(parser=parser) + + # dataset related arguments + parser = arguments_dataset(parser=parser) + + # Video reader related arguments + parser = arguments_video_reader(parser=parser) + + # collate fn related arguments + parser = arguments_collate_fn(parser=parser) + + # transform related arguments + parser = arguments_augmentation(parser=parser) + + # loss function arguments + parser = arguments_loss_fn(parser=parser) + + # optimizer arguments + parser = arguments_optimizer(parser=parser) + parser = arguments_scheduler(parser=parser) + + # DDP arguments + parser = arguments_ddp(parser=parser) + + # stats arguments + parser = arguments_stats(parser=parser) + + # common + parser = arguments_common(parser=parser) + + # wandb + parser.add_argument('--log-wandb', action='store_true', default=False, + help='log training and validation metrics to wandb') + # parser.set_defaults(log_wandb=True) + parser.add_argument('--experiment', default='debug', type=str, metavar='NAME', + help='name of train experiment, name of sub-folder for output') + + # text tokenizer arguments + parser = arguments_tokenizer(parser=parser) + + if parse_args: + return parser_to_opts(parser) + else: + return parser + + +def get_eval_arguments(parse_args=True): + return get_training_arguments(parse_args=parse_args) + + +def get_conversion_arguments(): + parser = get_training_arguments(parse_args=False) + + # Arguments related to coreml conversion + group = parser.add_argument_group("Conversion arguments") + group.add_argument( + "--conversion.coreml-extn", + type=str, + default="mlmodel", + help="Extension for converted model. Default is mlmodel", + ) + group.add_argument( + "--conversion.input-image-path", + type=str, + default=None, + help="Path of the image to be used for conversion", + ) + + # Arguments related to server. + group.add_argument( + "--conversion.bucket-name", type=str, help="Model job's bucket name" + ) + group.add_argument("--conversion.task-id", type=str, help="Model job's id") + group.add_argument( + "--conversion.viewers", + type=str, + nargs="+", + default=None, + help="Users who can view your models on server", + ) + + # parse args + return parser_to_opts(parser) + + +def get_bencmarking_arguments(): + parser = get_training_arguments(parse_args=False) + + # + group = parser.add_argument_group("Benchmarking arguments") + group.add_argument( + "--benchmark.batch-size", + type=int, + default=1, + help="Batch size for benchmarking", + ) + group.add_argument( + "--benchmark.warmup-iter", type=int, default=10, help="Warm-up iterations" + ) + group.add_argument( + "--benchmark.n-iter", + type=int, + default=100, + help="Number of iterations for benchmarking", + ) + group.add_argument( + "--benchmark.use-jit-model", + action="store_true", + help="Convert the model to JIT and then benchmark it", + ) + + # parse args + return parser_to_opts(parser) + + +def get_segmentation_eval_arguments(): + parser = get_training_arguments(parse_args=False) + + group = parser.add_argument_group("Segmentation evaluation related arguments") + group.add_argument( + "--evaluation.segmentation.apply-color-map", + action="store_true", + help="Apply color map to different classes in segmentation masks. Useful in visualization " + "+ some competitions (e.g, PASCAL VOC) accept submissions with colored segmentation masks", + ) + group.add_argument( + "--evaluation.segmentation.save-overlay-rgb-pred", + action="store_true", + help="enable this flag to visualize predicted masks on top of input image", + ) + group.add_argument( + "--evaluation.segmentation.save-masks", + action="store_true", + help="save predicted masks without colormaps. Useful for submitting to " + "competitions like Cityscapes", + ) + group.add_argument( + "--evaluation.segmentation.overlay-mask-weight", + default=0.5, + type=float, + help="Contribution of mask when overlaying on top of RGB image. ", + ) + group.add_argument( + "--evaluation.segmentation.mode", + type=str, + default="validation_set", + required=False, + choices=["single_image", "image_folder", "validation_set"], + help="Contribution of mask when overlaying on top of RGB image. ", + ) + group.add_argument( + "--evaluation.segmentation.path", + type=str, + default=None, + help="Path of the image or image folder (only required for single_image and image_folder modes)", + ) + group.add_argument( + "--evaluation.segmentation.num-classes", + type=str, + default=None, + help="Number of segmentation classes used during training", + ) + group.add_argument( + "--evaluation.segmentation.resize-input-images", + action="store_true", + help="Resize input images", + ) + + # parse args + return parser_to_opts(parser) + + +def get_detection_eval_arguments(): + parser = get_training_arguments(parse_args=False) + + group = parser.add_argument_group("Detection evaluation related arguments") + group.add_argument( + "--evaluation.detection.save-overlay-boxes", + action="store_true", + help="enable this flag to visualize predicted masks on top of input image", + ) + group.add_argument( + "--evaluation.detection.mode", + type=str, + default="validation_set", + required=False, + choices=["single_image", "image_folder", "validation_set"], + help="Contribution of mask when overlaying on top of RGB image. ", + ) + group.add_argument( + "--evaluation.detection.path", + type=str, + default=None, + help="Path of the image or image folder (only required for single_image and image_folder modes)", + ) + group.add_argument( + "--evaluation.detection.num-classes", + type=str, + default=None, + help="Number of segmentation classes used during training", + ) + group.add_argument( + "--evaluation.detection.resize-input-images", + action="store_true", + default=False, + help="Resize the input images", + ) + + # parse args + return parser_to_opts(parser) + + +def get_loss_landscape_args(): + parser = get_training_arguments(parse_args=False) + + group = parser.add_argument_group("Loss landscape related arguments") + group.add_argument( + "--loss-landscape.n-points", + type=int, + default=11, + help="No. of grid points. Default is 11, so we have 11x11 grid", + ) + group.add_argument( + "--loss-landscape.min-x", + type=float, + default=-1.0, + help="Min. value along x-axis", + ) + group.add_argument( + "--loss-landscape.max-x", + type=float, + default=1.0, + help="Max. value along x-axis", + ) + group.add_argument( + "--loss-landscape.min-y", + type=float, + default=-1.0, + help="Min. value along y-axis", + ) + group.add_argument( + "--loss-landscape.max-y", + type=float, + default=1.0, + help="Max. value along y-axis", + ) + + # parse args + return parser_to_opts(parser) diff --git a/Adaptive Frequency Filters/options/parse_args.py b/Adaptive Frequency Filters/options/parse_args.py new file mode 100644 index 0000000..3d15ac8 --- /dev/null +++ b/Adaptive Frequency Filters/options/parse_args.py @@ -0,0 +1,43 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + + +def parse_validation_metric_names(opts): + """ + This function contains common command-line parsing logic for validation metrics + """ + metric_names = getattr(opts, "stats.val", ["loss"]) + if isinstance(metric_names, str): + metric_names = [metric_names] + assert isinstance( + metric_names, list + ), "Type of metric names should be list. Got: {}".format(type(metric_names)) + + if "loss" not in metric_names: + metric_names.append("loss") + + ckpt_metric_str = getattr(opts, "stats.checkpoint_metric", "loss") + ckpt_metric_arr = ckpt_metric_str.split(".") + ckpt_metric = ckpt_metric_arr[0] + if len(ckpt_metric_arr) == 1: + ckpt_submetric_name = None + else: + ckpt_submetric_name = ckpt_metric_arr[-1] + + ckpt_metric = ckpt_metric + ckpt_submetric = ckpt_submetric_name + if ckpt_metric is None: + # if checkpoint metric is not specified, then use loss + ckpt_metric = "loss" + + assert ( + ckpt_metric in metric_names + ), "Checkpoint metric should be part of metric names. Metric names: {}, Checkpoint metric: {}".format( + metric_names, ckpt_metric + ) + ckpt_metric = ckpt_metric.lower() + + return metric_names, ckpt_metric, ckpt_submetric diff --git a/Adaptive Frequency Filters/options/utils.py b/Adaptive Frequency Filters/options/utils.py new file mode 100644 index 0000000..5cfacf7 --- /dev/null +++ b/Adaptive Frequency Filters/options/utils.py @@ -0,0 +1,119 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import argparse +import collections +import os + +import yaml + +from utils import logger +from utils.ddp_utils import is_master +from utils.download_utils import get_local_path + +try: + # Workaround for DeprecationWarning when importing Collections + collections_abc = collections.abc +except AttributeError: + collections_abc = collections + +DEFAULT_CONFIG_DIR = "config" + + +def flatten_yaml_as_dict(d, parent_key="", sep="."): + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, collections_abc.MutableMapping): + items.extend(flatten_yaml_as_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def load_config_file(opts): + config_file_name = getattr(opts, "common.config_file", None) + if config_file_name is None: + return opts + is_master_node = is_master(opts) + + if is_master_node: + config_file_name = get_local_path(opts=opts, path=config_file_name) + + if not os.path.isfile(config_file_name): + if len(config_file_name.split("/")) == 1: + # loading files from default config folder + new_config_file_name = "{}/{}".format(DEFAULT_CONFIG_DIR, config_file_name) + if not os.path.isfile(new_config_file_name) and is_master_node: + logger.error( + "Configuration file neither exists at {} nor at {}".format( + config_file_name, new_config_file_name + ) + ) + else: + config_file_name = new_config_file_name + else: + # If absolute path of the file is passed + if not os.path.isfile(config_file_name) and is_master_node: + logger.error( + "Configuration file does not exists at {}".format(config_file_name) + ) + + setattr(opts, "common.config_file", config_file_name) + with open(config_file_name, "r") as yaml_file: + try: + cfg = yaml.load(yaml_file, Loader=yaml.FullLoader) + + flat_cfg = flatten_yaml_as_dict(cfg) + for k, v in flat_cfg.items(): + if hasattr(opts, k): + setattr(opts, k, v) + except yaml.YAMLError as exc: + if is_master_node: + logger.error( + "Error while loading config file: {}. Error message: {}".format( + config_file_name, str(exc) + ) + ) + + # override arguments + override_args = getattr(opts, "override_args", None) + if override_args is not None: + for override_k, override_v in override_args.items(): + if hasattr(opts, override_k): + setattr(opts, override_k, override_v) + + return opts + + +def extend_selected_args_with_prefix( + parser: argparse.ArgumentParser, check_string: str, add_prefix: str +) -> argparse.ArgumentParser: + """ + Helper function to add a prefix to certain arguments. + An example use case is distillation, where we want to add --teacher as a prefix to all --model.* arguments + """ + # all arguments are stored as actions + options = parser._actions + + for option in options: + option_strings = option.option_strings + # option strings are stored as a list + for option_string in option_strings: + if option_string.split(".")[0] == check_string: + parser.add_argument( + add_prefix + option.dest.replace("_", "-"), + nargs="?" + if isinstance(option, argparse._StoreTrueAction) + else option.nargs, + const=option.const, + default=option.default, + type=option.type, + choices=option.choices, + help=option.help, + metavar=option.metavar, + ) + return parser diff --git a/Adaptive Frequency Filters/requirements.txt b/Adaptive Frequency Filters/requirements.txt new file mode 100644 index 0000000..3e4146e --- /dev/null +++ b/Adaptive Frequency Filters/requirements.txt @@ -0,0 +1,56 @@ +psutil +ujson +scikit-learn>=0.19.2 +scikit-image + +# requirement for Pytorch, Torchvision, TorchText +torch==1.13.1 +torchvision==0.13.1 +torchtext==0.13.1 # torchtext version needs to be compatible with PyTorch version +complexPyTorch==0.4 +torch-dct==0.1.6 + +numpy==1.23.5 + +# dependency for coremltools +coremltools==6.2 +nvidia-tensorrt==99.0.0 +tensorrt==8.5.3.1 + + +chardet==5.1.0 + +# dependency for MSCOCO dataset +pycocotools + +# dependency for reading and writing images +opencv-contrib-python==4.5.5.64 + +# dependency for cityscape evaluation +cityscapesscripts + +# added as a dependency to reproduce 3rd party models +pytorchvideo + +# PyAV for video decoding +av + +# FVCore for FLOP calculation +fvcore + +# black for reformatting +black + +# testing +pytest + +# torchtext for multi-model learning +ftfy + +# for hdf5 reading +h5py + +# for reading byte data +pybase64 + +click \ No newline at end of file diff --git a/Adaptive Frequency Filters/requirements_docs.txt b/Adaptive Frequency Filters/requirements_docs.txt new file mode 100644 index 0000000..d83db1e --- /dev/null +++ b/Adaptive Frequency Filters/requirements_docs.txt @@ -0,0 +1,5 @@ +# docs +sphinx +sphinx-rtd-theme +sphinx-argparse +myst-parser \ No newline at end of file diff --git a/Adaptive Frequency Filters/utils/__init__.py b/Adaptive Frequency Filters/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/utils/checkpoint_utils.py b/Adaptive Frequency Filters/utils/checkpoint_utils.py new file mode 100644 index 0000000..de7b6ac --- /dev/null +++ b/Adaptive Frequency Filters/utils/checkpoint_utils.py @@ -0,0 +1,314 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import torch +from typing import Optional, Union, Dict +import math +import glob + +from affnet import EMA +from optim import BaseOptim +from utils import logger +from utils.ddp_utils import is_master +from utils.download_utils import get_local_path + +CHECKPOINT_EXTN = "pt" + + +def get_model_state_dict(model): + if isinstance(model, EMA): + return get_model_state_dict(model.ema_model) + else: + return ( + model.module.state_dict() + if hasattr(model, "module") + else model.state_dict() + ) + + +def load_state_dict(model, state_dict): + if hasattr(model, "module"): + model.module.load_state_dict(state_dict) + else: + model.load_state_dict(state_dict) + return model + + +def average_ckpts(ckpt_loc_list: list): + avg_state_dict = dict() + key_count = dict() + key_dtype = dict() + + for c in ckpt_loc_list: + if not os.path.isfile(c): + pass + ckpt_state_dict = torch.load(c, map_location="cpu") + + for k, v in ckpt_state_dict.items(): + if k not in avg_state_dict: + key_dtype[k] = v.dtype + avg_state_dict[k] = v.clone().to(dtype=torch.float64) + key_count[k] = 1 + else: + avg_state_dict[k] += v.to(dtype=torch.float64) + key_count[k] += 1 + + for k, v in avg_state_dict.items(): + avg_state_dict[k] = v.div(key_count[k]).to(dtype=key_dtype[k]) + return avg_state_dict + + +def avg_n_save_k_checkpoints( + model_state, best_metric, k_best_checkpoints, max_ckpt_metric, ckpt_str +): + try: + ckpt_fname = "{}_score_{:.4f}.{}".format(ckpt_str, best_metric, CHECKPOINT_EXTN) + torch.save(model_state, ckpt_fname) + + best_fnames = glob.glob("{}_score_*".format(ckpt_str)) + best_scores = [ + float(f.split("_score_")[-1].replace(".{}".format(CHECKPOINT_EXTN), "")) + for f in best_fnames + ] + + best_scores_keep = [] + if len(best_scores) > k_best_checkpoints: + best_scores = sorted(best_scores) + if not max_ckpt_metric: + best_scores = best_scores[::-1] + best_scores_keep = best_scores[-k_best_checkpoints:] + for k in best_scores: + if k in best_scores_keep: + continue + rm_ckpt = "{}_score_{:.4f}.{}".format(ckpt_str, k, CHECKPOINT_EXTN) + os.remove(rm_ckpt) + logger.log("Deleting checkpoint: {}".format(rm_ckpt)) + # + if len(best_scores_keep) > 1: + avg_fnames = [ + "{}_score_{:.4f}.{}".format(ckpt_str, k, CHECKPOINT_EXTN) + for k in best_scores_keep + ] + logger.log( + "Averaging checkpoints: {}".format( + [f.split("/")[-1] for f in avg_fnames] + ) + ) + # save the average model + avg_model_state = average_ckpts(ckpt_loc_list=avg_fnames) + ckpt_fname = "{}_avg.{}".format(ckpt_str, CHECKPOINT_EXTN) + if avg_model_state: + torch.save(avg_model_state, ckpt_fname) + logger.log("Averaged checkpoint saved at: {}".format(ckpt_fname)) + except Exception as e: + logger.log("Error in k-best-checkpoint") + print(e) + + +def save_interval_checkpoint( + iterations: int, + epoch: int, + model: torch.nn.Module, + optimizer: Union[BaseOptim, torch.optim.Optimizer], + best_metric: float, + save_dir: str, + gradient_scalar: torch.cuda.amp.GradScaler, + not_intermediate_checkpoint: Optional[bool] = False, + *args, + **kwargs +) -> Dict: + model_state = get_model_state_dict(model) + checkpoint = { + "iterations": iterations, + "epoch": epoch, + "model_state_dict": model_state, + "optim_state_dict": optimizer.state_dict(), + "best_metric": best_metric, + "gradient_scalar_state_dict": gradient_scalar.state_dict(), + } + if not not_intermediate_checkpoint: + ckpt_str = "{}/checkpoint".format(save_dir) + ckpt_fname = "{}_{}_{}.{}".format(ckpt_str, epoch, iterations, CHECKPOINT_EXTN) + torch.save(checkpoint, ckpt_fname) + return checkpoint + + +def save_checkpoint( + iterations: int, + epoch: int, + model: torch.nn.Module, + optimizer: Union[BaseOptim, torch.optim.Optimizer], + best_metric: float, + is_best: bool, + save_dir: str, + gradient_scalar: torch.cuda.amp.GradScaler, + model_ema: Optional[torch.nn.Module] = None, + is_ema_best: Optional[bool] = False, + ema_best_metric: Optional[float] = None, + max_ckpt_metric: Optional[bool] = False, + k_best_checkpoints: Optional[int] = -1, + save_all_checkpoints: Optional[bool] = False, + *args, + **kwargs +) -> None: + model_state = get_model_state_dict(model) + + checkpoint = save_interval_checkpoint( + iterations=iterations, + epoch=epoch, + model=model, + optimizer=optimizer, + best_metric=best_metric, + save_dir=save_dir, + gradient_scalar=gradient_scalar, + not_intermediate_checkpoint=True, + ) + ckpt_str = "{}/checkpoint".format(save_dir) + + if is_best: + best_model_fname = "{}_best.{}".format(ckpt_str, CHECKPOINT_EXTN) + if os.path.isfile(best_model_fname): + os.remove(best_model_fname) + torch.save(model_state, best_model_fname) + logger.log( + "Best checkpoint with score {:.2f} saved at {}".format( + best_metric, best_model_fname + ) + ) + + if k_best_checkpoints > 1: + avg_n_save_k_checkpoints( + model_state, best_metric, k_best_checkpoints, max_ckpt_metric, ckpt_str + ) + + if model_ema is not None: + checkpoint["ema_state_dict"] = get_model_state_dict(model_ema) + ema_fname = "{}_ema.{}".format(ckpt_str, CHECKPOINT_EXTN) + torch.save(checkpoint["ema_state_dict"], ema_fname) + if save_all_checkpoints: + ema_fname = "{}_ema_epoch{}.{}".format(ckpt_str, epoch, CHECKPOINT_EXTN) + torch.save(checkpoint["ema_state_dict"], ema_fname) + + if is_ema_best: + ema_best_fname = "{}_ema_best.{}".format(ckpt_str, CHECKPOINT_EXTN) + if os.path.isfile(ema_best_fname): + os.remove(ema_best_fname) + torch.save(checkpoint["ema_state_dict"], ema_best_fname) + logger.log( + "Best EMA checkpoint with score {:.2f} saved at {}".format( + ema_best_metric, ema_best_fname + ) + ) + + if k_best_checkpoints > 1 and ema_best_metric is not None: + avg_n_save_k_checkpoints( + model_state=checkpoint["ema_state_dict"], + best_metric=ema_best_metric, + k_best_checkpoints=k_best_checkpoints, + max_ckpt_metric=max_ckpt_metric, + ckpt_str="{}_ema".format(ckpt_str), + ) + + ckpt_fname = "{}.{}".format(ckpt_str, CHECKPOINT_EXTN) + torch.save(checkpoint, ckpt_fname) + + ckpt_fname = "{}_last.{}".format(ckpt_str, CHECKPOINT_EXTN) + torch.save(model_state, ckpt_fname) + + if save_all_checkpoints: + ckpt_fname = "{}_epoch{}.{}".format(ckpt_str, epoch, CHECKPOINT_EXTN) + torch.save(model_state, ckpt_fname) + + +def load_checkpoint( + opts, + model: torch.nn.Module, + optimizer: Union[BaseOptim, torch.optim.Optimizer], + gradient_scalar: torch.cuda.amp.GradScaler, + model_ema: Optional[torch.nn.Module] = None, +): + resume_loc = getattr(opts, "common.resume", None) + dev_id = getattr(opts, "dev.device_id", None) + device = getattr(opts, "dev.device", torch.device("cpu")) + start_epoch = start_iteration = 0 + best_metric = ( + 0.0 if getattr(opts, "stats.checkpoint_metric_max", False) else math.inf + ) + auto_resume = getattr(opts, "common.auto_resume", False) + exp_dir = getattr(opts, "common.exp_loc", None) + is_master_node = is_master(opts) + if resume_loc is None and auto_resume and exp_dir is not None: + resume_loc = "{}/checkpoint.{}".format(exp_dir, CHECKPOINT_EXTN) + + resume_loc = get_local_path(opts, path=resume_loc) + if resume_loc is not None and os.path.isfile(resume_loc): + if dev_id is None: + checkpoint = torch.load(resume_loc, map_location=device) + else: + checkpoint = torch.load(resume_loc, map_location="cuda:{}".format(dev_id)) + + start_epoch = checkpoint["epoch"] + 1 + start_iteration = checkpoint["iterations"] + 1 + best_metric = checkpoint["best_metric"] + + model = load_state_dict(model, checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optim_state_dict"]) + gradient_scalar.load_state_dict(checkpoint["gradient_scalar_state_dict"]) + + if model_ema is not None and "ema_state_dict" in checkpoint: + model_ema.ema_model = load_state_dict( + model_ema.ema_model, checkpoint["ema_state_dict"] + ) + + if is_master_node: + logger.log("Loaded checkpoint from {}".format(resume_loc)) + logger.log("Resuming training for epoch {}".format(start_epoch)) + else: + if is_master_node: + logger.log("No checkpoint found at '{}'".format(resume_loc)) + return ( + model, + optimizer, + gradient_scalar, + start_epoch, + start_iteration, + best_metric, + model_ema, + ) + + +def load_model_state(opts, model, model_ema=None): + dev_id = getattr(opts, "dev.device_id", None) + device = getattr(opts, "dev.device", torch.device("cpu")) + finetune_loc = getattr(opts, "common.finetune_imagenet1k", None) + finetune_ema_loc = getattr(opts, "common.finetune_ema", None) + + def load_state(path): + path = get_local_path(opts, path=path) + if dev_id is None: + model_state = torch.load(path, map_location=device) + else: + model_state = torch.load(path, map_location="cuda:{}".format(dev_id)) + return model_state + + if finetune_loc is not None and os.path.isfile(finetune_loc): + # load model dict + model = load_state_dict(model, load_state(finetune_loc)) + + # load ema dict + if model_ema is not None and os.path.isfile(finetune_ema_loc): + model_ema = load_state_dict(model, load_state(finetune_ema_loc)) + + return model, model_ema + + +def copy_weights( + model_src: torch.nn.Module, model_tgt: torch.nn.Module +) -> torch.nn.Module: + with torch.no_grad(): + model_state = get_model_state_dict(model=model_src) + return load_state_dict(model=model_tgt, state_dict=model_state) diff --git a/Adaptive Frequency Filters/utils/color_map.py b/Adaptive Frequency Filters/utils/color_map.py new file mode 100644 index 0000000..fa94a74 --- /dev/null +++ b/Adaptive Frequency Filters/utils/color_map.py @@ -0,0 +1,62 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import numpy as np +from typing import Optional, List + + +class Colormap(object): + """ + Generate colormap for visualizing segmentation masks or bounding boxes. + + This is based on the MATLab code in the PASCAL VOC repository: + http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit + """ + + def __init__(self, n: Optional[int] = 256, normalized: Optional[bool] = False): + super(Colormap, self).__init__() + self.n = n + self.normalized = normalized + + @staticmethod + def get_bit_at_idx(val, idx): + return (val & (1 << idx)) != 0 + + def get_color_map(self) -> np.ndarray: + + dtype = "float32" if self.normalized else "uint8" + color_map = np.zeros((self.n, 3), dtype=dtype) + for i in range(self.n): + r = g = b = 0 + c = i + for j in range(8): + r = r | (self.get_bit_at_idx(c, 0) << 7 - j) + g = g | (self.get_bit_at_idx(c, 1) << 7 - j) + b = b | (self.get_bit_at_idx(c, 2) << 7 - j) + c = c >> 3 + + color_map[i] = np.array([r, g, b]) + color_map = color_map / 255 if self.normalized else color_map + return color_map + + def get_box_color_codes(self) -> List: + box_codes = [] + + for i in range(self.n): + r = g = b = 0 + c = i + for j in range(8): + r = r | (self.get_bit_at_idx(c, 0) << 7 - j) + g = g | (self.get_bit_at_idx(c, 1) << 7 - j) + b = b | (self.get_bit_at_idx(c, 2) << 7 - j) + c = c >> 3 + box_codes.append((int(r), int(g), int(b))) + return box_codes + + def get_color_map_list(self) -> List: + cmap = self.get_color_map() + cmap = np.asarray(cmap).flatten() + return list(cmap) diff --git a/Adaptive Frequency Filters/utils/common_utils.py b/Adaptive Frequency Filters/utils/common_utils.py new file mode 100644 index 0000000..829dcb9 --- /dev/null +++ b/Adaptive Frequency Filters/utils/common_utils.py @@ -0,0 +1,128 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import random +from typing import Dict, Optional, List, Tuple, Any + +import numpy as np +import torch +from packaging import version +from torch import Tensor + +from common import MIN_TORCH_VERSION +from affnet.layers import norm_layers_tuple +from utils import logger +from utils.ddp_utils import is_master + + +def check_compatibility() -> None: + curr_torch_version = torch.__version__ + if version.parse(curr_torch_version) < version.parse(MIN_TORCH_VERSION): + logger.error( + "Min. pytorch version required is {}. Got: {}".format( + MIN_TORCH_VERSION, curr_torch_version + ) + ) + + +def check_frozen_norm_layer(model: torch.nn.Module) -> Tuple[bool, int]: + + if hasattr(model, "module"): + model = model.module + + count_norm = 0 + frozen_state = False + for m in model.modules(): + if isinstance(m, norm_layers_tuple): + frozen_state = m.weight.requires_grad + + return frozen_state, count_norm + + +def device_setup(opts): + """Helper function for setting up the device""" + random_seed = getattr(opts, "common.seed", 0) + random.seed(random_seed) + torch.manual_seed(random_seed) + np.random.seed(random_seed) + + is_master_node = is_master(opts) + if is_master_node: + logger.log("Random seeds are set to {}".format(random_seed)) + logger.log("Using PyTorch version {}".format(torch.__version__)) + + n_gpus = torch.cuda.device_count() + if n_gpus == 0: + if is_master_node: + logger.warning("No GPUs available. Using CPU") + device = torch.device("cpu") + n_gpus = 0 + else: + if is_master_node: + logger.log("Available GPUs: {}".format(n_gpus)) + device = torch.device("cuda") + + if torch.backends.cudnn.is_available(): + import torch.backends.cudnn as cudnn + + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + if is_master_node: + logger.log("CUDNN is enabled") + + allow_tf32 = not getattr(opts, "common.disable_tf32", False) + if torch.cuda.is_available(): + # TF32 is enabled by default in PyTorch < 1.12, but disabled in new versions. + # See for details: https://github.com/pytorch/pytorch/issues/67384 + # Disable it using common.disable_tf32 flag + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + + setattr(opts, "dev.device", device) + setattr(opts, "dev.num_gpus", n_gpus) + + return opts + + +def create_directories(dir_path: str, is_master_node: bool) -> None: + """Helper function to create directories""" + if not os.path.isdir(dir_path): + os.makedirs(dir_path, exist_ok=True) + if is_master_node: + logger.log("Directory created at: {}".format(dir_path)) + else: + if is_master_node: + logger.log("Directory exists at: {}".format(dir_path)) + + +def move_to_device( + opts, + x: Any, + device: Optional[str] = "cpu", + non_blocking: Optional[bool] = True, + *args, + **kwargs +) -> Any: + """Helper function to move data to a device""" + if isinstance(x, Dict): + for k, v in x.items(): + x[k] = move_to_device( + opts=opts, x=v, device=device, non_blocking=non_blocking + ) + + elif isinstance(x, Tensor): + # only tensors can be moved to a device + x = x.to(device=device, non_blocking=non_blocking) + elif isinstance(x, List): + x = [move_to_device(opts, a, device, non_blocking) for a in x] + return x + + +def is_coreml_conversion(opts) -> bool: + if getattr(opts, "common.enable_coreml_compatible_module", False): + return True + return False diff --git a/Adaptive Frequency Filters/utils/ddp_utils.py b/Adaptive Frequency Filters/utils/ddp_utils.py new file mode 100644 index 0000000..9a33928 --- /dev/null +++ b/Adaptive Frequency Filters/utils/ddp_utils.py @@ -0,0 +1,89 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + + +import socket +import torch +import torch.distributed as dist +from typing import Optional + +from utils import logger + + +def is_master(opts) -> bool: + node_rank = getattr(opts, "ddp.rank", 0) + return node_rank == 0 + + +def dist_barrier(): + dist.barrier() + + +def dist_monitored_barrier( + timeout: Optional[float] = None, + wait_all_ranks: Optional[bool] = False, + group: Optional = None, +): + dist.monitored_barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks) + + +def is_start_rank_node(opts) -> bool: + node_rank = getattr(opts, "ddp.rank", 0) + def_rank = getattr(opts, "ddp.start_rank", 0) + return node_rank == def_rank + + +def get_world_size(): + return dist.get_world_size() + + +def get_node_rank(): + return dist.get_rank() + + +def distributed_init(opts) -> int: + ddp_url = getattr(opts, "ddp.dist_url", None) + is_master_node = is_master(opts) + if ddp_url is None: + ddp_port = getattr(opts, "ddp.dist_port", 6006) + hostname = socket.gethostname() + ddp_url = "tcp://{}:{}".format(hostname, ddp_port) + setattr(opts, "ddp.dist_url", ddp_url) + + node_rank = getattr(opts, "ddp.rank", 0) + world_size = getattr(opts, "ddp.world_size", 0) + if torch.distributed.is_initialized(): + logger.warning("DDP is already initialized and cannot be initialize twice!") + else: + logger.info("distributed init (rank {}): {}".format(node_rank, ddp_url)) + + dist_backend = getattr(opts, "ddp.backend", "nccl") # "gloo" + + if dist_backend is None and dist.is_nccl_available(): + dist_backend = "nccl" + if is_master_node: + logger.log( + "Using NCCL as distributed backend with version={}".format( + torch.cuda.nccl.version() + ) + ) + elif dist_backend is None: + dist_backend = "gloo" + + dist.init_process_group( + backend=dist_backend, + init_method=ddp_url, + world_size=world_size, + rank=node_rank, + ) + + # perform a dummy all-reduce to initialize the NCCL communicator + if torch.cuda.is_available(): + dist.all_reduce(torch.zeros(1).cuda()) + + node_rank = torch.distributed.get_rank() + setattr(opts, "ddp.rank", node_rank) + return node_rank diff --git a/Adaptive Frequency Filters/utils/download_utils.py b/Adaptive Frequency Filters/utils/download_utils.py new file mode 100644 index 0000000..8bd01b4 --- /dev/null +++ b/Adaptive Frequency Filters/utils/download_utils.py @@ -0,0 +1,15 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from .download_utils_base import get_basic_local_path + +try: + from internal.utils.blobby_utils import get_local_path_blobby + + get_local_path = get_local_path_blobby + +except ModuleNotFoundError as mnfe: + get_local_path = get_basic_local_path diff --git a/Adaptive Frequency Filters/utils/download_utils_base.py b/Adaptive Frequency Filters/utils/download_utils_base.py new file mode 100644 index 0000000..a52a0f7 --- /dev/null +++ b/Adaptive Frequency Filters/utils/download_utils_base.py @@ -0,0 +1,89 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import os +import copy +import time +import requests + +from common import TMP_CACHE_LOC +from utils.ddp_utils import is_start_rank_node, dist_barrier +from utils import logger + + +def get_basic_local_path( + opts, + path, + cache_loc=TMP_CACHE_LOC, + force_delete=True, + use_start_rank=True, + sync_ranks=True, + *args, + **kwargs +): + """ + If File name is a URL, download to TMP_CACHE_LOC and then return the local path. Otherwise, don't do anything + """ + if ( + path.find("s3://") > -1 + or path.find("http://") > -1 + or path.find("https://") > -1 + ): + url_path = copy.deepcopy(path) + ckpt_name = path.split(os.sep)[-1] + local_path = "{}/{}".format(cache_loc, ckpt_name) + local_path = str(local_path).strip() + + if os.path.isfile(local_path) and force_delete: + # If file exists, remove it and then download again + # This is important because if we are downloading from bolt tasks, then checkpoint names are the same + if use_start_rank: + # remove files from start rank only + if is_start_rank_node(opts): + os.remove(local_path) + else: + while not os.path.isfile(local_path): + time.sleep(1) + continue + else: + # All ranks in DDP can remove the files + os.remove(local_path) + + if not os.path.isfile(local_path): + if not use_start_rank or is_start_rank_node(opts): + _download_file(url_path, local_path) + else: + while os.path.isfile(local_path): + # download file on start rank and let other ranks keep waiting till file is downloaded + # in DDP, download file in all ranks + time.sleep(1) + continue + + if getattr(opts, "ddp.use_distributed", False) and sync_ranks: + # synchronize between processes + dist_barrier() + return local_path + return path + + +def _download_file(url_path: str, dest_loc: str) -> None: + """ + Helper function to download a file with proxy (used when file fails) + """ + response = requests.get(url_path, stream=True) + if response.status_code == 403: + # try with the HTTP/HTTPS proxy from ENV + proxies = { + "https": os.environ.get("HTTPS_PROXY", None), + "http": os.environ.get("HTTP_PROXY", None), + } + response = requests.get(url_path, stream=True, proxies=proxies) + + if response.status_code == 200: + with open(dest_loc, "wb") as f: + f.write(response.raw.read()) + else: + logger.error("Unable to download file {}".format(url_path)) diff --git a/Adaptive Frequency Filters/utils/logger.py b/Adaptive Frequency Filters/utils/logger.py new file mode 100644 index 0000000..8b8b927 --- /dev/null +++ b/Adaptive Frequency Filters/utils/logger.py @@ -0,0 +1,127 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import time +from typing import Optional +import sys +import os + +text_colors = { + "logs": "\033[34m", # 033 is the escape code and 34 is the color code + "info": "\033[32m", + "warning": "\033[33m", + "debug": "\033[93m", + "error": "\033[31m", + "bold": "\033[1m", + "end_color": "\033[0m", + "light_red": "\033[36m", +} + + +def get_curr_time_stamp() -> str: + return time.strftime("%Y-%m-%d %H:%M:%S") + + +def error(message: str) -> None: + time_stamp = get_curr_time_stamp() + error_str = ( + text_colors["error"] + + text_colors["bold"] + + "ERROR " + + text_colors["end_color"] + ) + + # exiting with code -1 does not tell any information about the error (e.g., NaN encountered in the loss). + # For more descriptive error messages, we replace exit(-1) with sys.exit(ERROR_MESSAGE). + # This allows us to handle specific exceptions in the tests. + + # print("{} - {} - {}".format(time_stamp, error_str, message), flush=True) + # print("{} - {} - {}".format(time_stamp, error_str, "Exiting!!!"), flush=True) + # exit(-1) + + sys.exit("{} - {} - {}. Exiting!!!".format(time_stamp, error_str, message)) + + +def color_text(in_text: str) -> str: + return text_colors["light_red"] + in_text + text_colors["end_color"] + + +def log(message: str, end="\n") -> None: + time_stamp = get_curr_time_stamp() + log_str = ( + text_colors["logs"] + text_colors["bold"] + "LOGS " + text_colors["end_color"] + ) + print("{} - {} - {}".format(time_stamp, log_str, message), end=end) + + +def warning(message: str) -> None: + time_stamp = get_curr_time_stamp() + warn_str = ( + text_colors["warning"] + + text_colors["bold"] + + "WARNING" + + text_colors["end_color"] + ) + print("{} - {} - {}".format(time_stamp, warn_str, message)) + + +def info(message: str, print_line: Optional[bool] = False) -> None: + time_stamp = get_curr_time_stamp() + info_str = ( + text_colors["info"] + text_colors["bold"] + "INFO " + text_colors["end_color"] + ) + print("{} - {} - {}".format(time_stamp, info_str, message)) + if print_line: + double_dash_line(dashes=150) + + +def debug(message: str) -> None: + time_stamp = get_curr_time_stamp() + log_str = ( + text_colors["debug"] + + text_colors["bold"] + + "DEBUG " + + text_colors["end_color"] + ) + print("{} - {} - {}".format(time_stamp, log_str, message)) + + +def double_dash_line(dashes: Optional[int] = 75) -> None: + print(text_colors["error"] + "=" * dashes + text_colors["end_color"]) + + +def singe_dash_line(dashes: Optional[int] = 67) -> None: + print("-" * dashes) + + +def print_header(header: str) -> None: + double_dash_line() + print( + text_colors["info"] + + text_colors["bold"] + + "=" * 50 + + str(header) + + text_colors["end_color"] + ) + double_dash_line() + + +def print_header_minor(header: str) -> None: + print( + text_colors["warning"] + + text_colors["bold"] + + "=" * 25 + + str(header) + + text_colors["end_color"] + ) + + +def disable_printing(): + sys.stdout = open(os.devnull, "w") + + +def enable_printing(): + sys.stdout = sys.__stdout__ diff --git a/Adaptive Frequency Filters/utils/math_utils.py b/Adaptive Frequency Filters/utils/math_utils.py new file mode 100644 index 0000000..6c2fe3f --- /dev/null +++ b/Adaptive Frequency Filters/utils/math_utils.py @@ -0,0 +1,37 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +from typing import Union, Optional + + +def make_divisible( + v: Union[float, int], + divisor: Optional[int] = 8, + min_value: Optional[Union[float, int]] = None, +) -> Union[float, int]: + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def bound_fn( + min_val: Union[float, int], max_val: Union[float, int], value: Union[float, int] +) -> Union[float, int]: + return max(min_val, min(max_val, value)) diff --git a/Adaptive Frequency Filters/utils/my_dataset_folder.py b/Adaptive Frequency Filters/utils/my_dataset_folder.py new file mode 100644 index 0000000..be11091 --- /dev/null +++ b/Adaptive Frequency Filters/utils/my_dataset_folder.py @@ -0,0 +1,100 @@ +# -------------------------------------------------------- +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License +# Written by Guoqiang Wei +# -------------------------------------------------------- + +from torchvision.datasets.vision import VisionDataset +from torchvision.datasets.folder import IMG_EXTENSIONS, make_dataset, default_loader + +import os +import os.path +from typing import Any, Callable, Dict, List, Optional, Tuple + + +def make_dataset_with_ann(ann_file, img_prefix, extensions): + images = [] + with open(ann_file, "r") as f: + contents = f.readlines() + for line_str in contents: + path_contents = [c for c in line_str.split('\t')] + im_file_name = path_contents[0] + class_index = int(path_contents[1]) + + assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions + item = (os.path.join(img_prefix, im_file_name), class_index) + + images.append(item) + + return images + + +class CustomDatasetFolder(VisionDataset): + def __init__( + self, + root: str, + loader: Callable[[str], Any], + extensions: Optional[Tuple[str, ...]] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, + map_txt: str = None, + ) -> None: + super(CustomDatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform) + + if os.path.isfile(map_txt): + samples = make_dataset_with_ann(map_txt, root, extensions=extensions) + else: + classes, class_to_idx = self._find_classes(self.root) + samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) + if len(samples) == 0: + msg = "Found 0 files in subfolders of: {}\n".format(self.root) + if extensions is not None: + msg += "Supported extensions are: {}".format(",".join(extensions)) + raise RuntimeError(msg) + + self.loader = loader + self.extensions = extensions + + self.labels = [y_1k for _, y_1k in samples] + self.classes = list(set(self.labels)) + # self.class_to_idx = class_to_idx + self.samples = samples + self.targets = [s[1] for s in samples] + + def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: + classes = [d.name for d in os.scandir(dir) if d.is_dir()] + classes.sort() + class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} + return classes, class_to_idx + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + path, target = self.samples[index] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target + + def __len__(self) -> int: + return len(self.samples) + + +class ImageFolder(CustomDatasetFolder): + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + loader: Callable[[str], Any] = default_loader, + is_valid_file: Optional[Callable[[str], bool]] = None, + map_txt: str = None, + ): + super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, + transform=transform, + target_transform=target_transform, + is_valid_file=is_valid_file, + map_txt=map_txt) + self.imgs = self.samples diff --git a/Adaptive Frequency Filters/utils/pytorch_to_coreml.py b/Adaptive Frequency Filters/utils/pytorch_to_coreml.py new file mode 100644 index 0000000..057a7ef --- /dev/null +++ b/Adaptive Frequency Filters/utils/pytorch_to_coreml.py @@ -0,0 +1,120 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +from torch import Tensor +import coremltools as ct +from typing import Optional, Dict, Tuple, Union +import numpy as np +from PIL import Image +from torchvision.transforms import functional as F + + +from utils.tensor_utils import create_rand_tensor +from torch.utils.mobile_optimizer import optimize_for_mobile + + +def convert_pytorch_to_coreml( + opts, + pytorch_model: torch.nn.Module, + jit_model_only: Optional[bool] = False, + *args, + **kwargs +) -> Dict: + """ + Convert Pytorch model to CoreML + + :param opts: Arguments + :param pytorch_model: Pytorch model that needs to be converted to JIT or CoreML + :param input_tensor: Input tensor, usually a 4-dimensional tensor of shape Batch x 3 x Height x Width + :return: CoreML model or package + """ + + input_image_path = getattr(opts, "conversion.input_image_path", None) + if input_image_path is not None: + input_pil_img = Image.open(input_image_path).convert("RGB") + input_pil_img = F.resize( + img=input_pil_img, size=256, interpolation=F.InterpolationMode.BILINEAR + ) + input_pil_img = F.center_crop(img=input_pil_img, output_size=224) + input_tensor = F.pil_to_tensor(input_pil_img).float() + input_tensor.div_(255.0) + input_tensor = input_tensor.unsqueeze(0) # add dummy batch dimension + else: + input_pil_img = None + input_tensor = create_rand_tensor(opts=opts, device="cpu") + + if pytorch_model.training: + pytorch_model.eval() + + with torch.no_grad(): + pytorch_out = pytorch_model(input_tensor) + + jit_model = torch.jit.trace(pytorch_model, input_tensor) + jit_out = jit_model(input_tensor) + assertion_check(py_out=pytorch_out, jit_out=jit_out) + + jit_model_optimized = optimize_for_mobile(jit_model) + jit_optimzied_out = jit_model_optimized(input_tensor) + assertion_check(py_out=pytorch_out, jit_out=jit_optimzied_out) + + if jit_model_only and torch.cuda.device_count() > 0: + # For inference on GPU + return {"coreml": None, "jit": jit_model, "jit_optimized": None} + elif jit_model_only and torch.cuda.device_count() == 0: + # For inference on CPU + return {"coreml": None, "jit": jit_model_optimized, "jit_optimized": None} + + coreml_model = ct.convert( + model=jit_model, + inputs=[ + ct.ImageType(name="input", shape=input_tensor.shape, scale=1.0 / 255.0) + ], + convert_to="neuralnetwork", # mlprogram + # preprocessing_args={"scale": 1.0/255.0}, + # minimum_deployment_target=ct.target.iOS15, + # compute_precision=ct.precision.FLOAT16 + ) + + if input_pil_img is not None: + out = coreml_model.predict({"input": input_pil_img}) + + return { + "coreml": coreml_model, + "jit": jit_model, + "jit_optimized": jit_model_optimized, + } + + +def assertion_check( + py_out: Union[Tensor, Dict, Tuple], jit_out: Union[Tensor, Dict, Tuple] +) -> None: + if isinstance(py_out, Dict): + assert isinstance(jit_out, Dict) + keys = py_out.keys() + for k in keys: + np.testing.assert_almost_equal( + py_out[k].cpu().numpy(), + jit_out[k].cpu().numpy(), + decimal=3, + verbose=True, + ) + elif isinstance(py_out, Tensor): + assert isinstance(jit_out, Tensor) + np.testing.assert_almost_equal( + py_out.cpu().numpy(), jit_out.cpu().numpy(), decimal=3, verbose=True + ) + elif isinstance(py_out, Tuple): + assert isinstance(jit_out, Tuple) + for x, y in zip(py_out, jit_out): + np.testing.assert_almost_equal( + x.cpu().numpy(), y.cpu().numpy(), decimal=3, verbose=True + ) + + else: + raise NotImplementedError( + "Only Dictionary[Tensors] or Tuple[Tensors] or Tensors are supported as outputs" + ) diff --git a/Adaptive Frequency Filters/utils/tensor_utils.py b/Adaptive Frequency Filters/utils/tensor_utils.py new file mode 100644 index 0000000..0437e7b --- /dev/null +++ b/Adaptive Frequency Filters/utils/tensor_utils.py @@ -0,0 +1,157 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import numpy as np +import torch +from torch import Tensor +from torch import distributed as dist +from typing import Union, Optional, Tuple + +from utils.third_party.ddp_functional_utils import ( + all_gather as all_gather_with_backward, +) +from common import ( + DEFAULT_IMAGE_HEIGHT, + DEFAULT_IMAGE_WIDTH, + DEFAULT_IMAGE_CHANNELS, + DEFAULT_VIDEO_FRAMES, +) + + +def image_size_from_opts(opts) -> Tuple[int, int]: + try: + sampler_name = getattr(opts, "sampler.name", "variable_batch_sampler").lower() + if sampler_name.find("var") > -1: + im_w = getattr(opts, "sampler.vbs.crop_size_width", DEFAULT_IMAGE_WIDTH) + im_h = getattr(opts, "sampler.vbs.crop_size_height", DEFAULT_IMAGE_HEIGHT) + elif sampler_name.find("multi") > -1: + im_w = getattr(opts, "sampler.msc.crop_size_width", DEFAULT_IMAGE_WIDTH) + im_h = getattr(opts, "sampler.msc.crop_size_height", DEFAULT_IMAGE_HEIGHT) + else: + im_w = getattr(opts, "sampler.bs.crop_size_width", DEFAULT_IMAGE_WIDTH) + im_h = getattr(opts, "sampler.bs.crop_size_height", DEFAULT_IMAGE_HEIGHT) + except Exception as e: + im_h = DEFAULT_IMAGE_HEIGHT + im_w = DEFAULT_IMAGE_WIDTH + return im_h, im_w + + +def video_size_from_opts(opts) -> Tuple[int, int, int]: + try: + sampler_name = getattr(opts, "sampler.name", "video_batch_sampler").lower() + if sampler_name.find("var") > -1: + im_w = getattr(opts, "sampler.vbs.crop_size_width", DEFAULT_IMAGE_WIDTH) + im_h = getattr(opts, "sampler.vbs.crop_size_height", DEFAULT_IMAGE_HEIGHT) + n_frames = getattr( + opts, "sampler.vbs.num_frames_per_clip", DEFAULT_IMAGE_HEIGHT + ) + else: + im_w = getattr(opts, "sampler.bs.crop_size_width", DEFAULT_IMAGE_WIDTH) + im_h = getattr(opts, "sampler.bs.crop_size_height", DEFAULT_IMAGE_HEIGHT) + n_frames = getattr( + opts, "sampler.bs.num_frames_per_clip", DEFAULT_IMAGE_HEIGHT + ) + except Exception as e: + im_h = DEFAULT_IMAGE_HEIGHT + im_w = DEFAULT_IMAGE_WIDTH + n_frames = DEFAULT_VIDEO_FRAMES + return im_h, im_w, n_frames + + +def create_rand_tensor( + opts, device: Optional[str] = "cpu", batch_size: Optional[int] = 1 +) -> Tensor: + sampler = getattr(opts, "sampler.name", "batch_sampler") + if sampler.lower().find("video") > -1: + video_stack = getattr(opts, "video_reader.frame_stack_format", "channel_first") + im_h, im_w, n_frames = video_size_from_opts(opts=opts) + if video_stack == "channel_first": + inp_tensor = torch.randint( + low=0, + high=255, + size=(batch_size, DEFAULT_IMAGE_CHANNELS, n_frames, im_h, im_w), + device=device, + ) + else: + inp_tensor = torch.randint( + low=0, + high=255, + size=(batch_size, n_frames, DEFAULT_IMAGE_CHANNELS, im_h, im_w), + device=device, + ) + else: + im_h, im_w = image_size_from_opts(opts=opts) + inp_tensor = torch.randint( + low=0, + high=255, + size=(batch_size, DEFAULT_IMAGE_CHANNELS, im_h, im_w), + device=device, + ) + inp_tensor = inp_tensor.float().div(255.0) + return inp_tensor + + +def reduce_tensor(inp_tensor: torch.Tensor) -> torch.Tensor: + size = dist.get_world_size() if dist.is_initialized() else 1 + inp_tensor_clone = inp_tensor.clone().detach() + # dist_barrier() + dist.all_reduce(inp_tensor_clone, op=dist.ReduceOp.SUM) + inp_tensor_clone /= size + return inp_tensor_clone + + +def reduce_tensor_sum(inp_tensor: torch.Tensor) -> torch.Tensor: + inp_tensor_clone = inp_tensor.clone().detach() + # dist_barrier() + dist.all_reduce(inp_tensor_clone, op=dist.ReduceOp.SUM) + return inp_tensor_clone + + +def all_gather_list(data): + world_size = dist.get_world_size() + data_list = [None] * world_size + # dist_barrier() + dist.all_gather_object(data_list, data) + return data_list + + +def gather_all_features(features: Tensor, dim=0): + return torch.cat(all_gather_with_backward(features), dim=dim) + # world_size = dist.get_world_size() + # gathered_data = [torch.zeros_like(features)] * world_size + # dist.all_gather(gathered_data, features) + # gathered_data = torch.cat(gathered_data, dim=dim) + # return gathered_data + + +def tensor_to_python_float( + inp_tensor: Union[int, float, torch.Tensor], is_distributed: bool +) -> Union[int, float, np.ndarray]: + if is_distributed and isinstance(inp_tensor, torch.Tensor): + inp_tensor = reduce_tensor(inp_tensor=inp_tensor) + + if isinstance(inp_tensor, torch.Tensor) and inp_tensor.numel() > 1: + # For IOU, we get a C-dimensional tensor (C - number of classes) + # so, we convert here to a numpy array + return inp_tensor.cpu().numpy() + elif hasattr(inp_tensor, "item"): + return inp_tensor.item() + elif isinstance(inp_tensor, (int, float)): + return inp_tensor * 1.0 + else: + raise NotImplementedError( + "The data type is not supported yet in tensor_to_python_float function" + ) + + +def to_numpy(img_tensor: torch.Tensor) -> np.ndarray: + # [0, 1] --> [0, 255] + img_tensor = torch.mul(img_tensor, 255.0) + # BCHW --> BHWC + img_tensor = img_tensor.permute(0, 2, 3, 1) + + img_np = img_tensor.byte().cpu().numpy() + return img_np diff --git a/Adaptive Frequency Filters/utils/third_party/__init__.py b/Adaptive Frequency Filters/utils/third_party/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Adaptive Frequency Filters/utils/third_party/ddp_functional_utils.py b/Adaptive Frequency Filters/utils/third_party/ddp_functional_utils.py new file mode 100644 index 0000000..297c5de --- /dev/null +++ b/Adaptive Frequency Filters/utils/third_party/ddp_functional_utils.py @@ -0,0 +1,466 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- + +import torch +import torch.distributed as dist +from torch.autograd import Function + +""" This code is borrowed from PyTorch in order for affnet to be compatbile with versions < 1.12""" + +# The two imports below are not always available depending on the +# USE_DISTRIBUTED compile flag. Make sure they raise import error +# if we're trying to use them. +try: + from torch.distributed import group, ReduceOp +except ModuleNotFoundError as mnfe: + raise ModuleNotFoundError( + "group and ReduceOp are not found. Make sure that you are using PyTorch>=1.12" + ) + + +def broadcast(tensor, src, group=group.WORLD): + """ + Broadcasts the tensor to the whole group. + + ``tensor`` must have the same number of elements in all processes + participating in the collective. + + Arguments: + tensor (Tensor): Data to be sent if ``src`` is the rank of current + process. + src (int): Source rank. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Received tensor from the broadcast op. + + """ + return _Broadcast.apply(src, group, tensor) + + +def gather(tensor, dst=0, group=group.WORLD): + """ + Gathers a list of tensors in a single process. + + Arguments: + tensor (Tensor): Input tensor. + dst (int, optional): Destination rank (default is 0). + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple[Tensor]: List of appropriately-sized tensors with the gathered data. + """ + return _Gather.apply(dst, group, tensor) + + +def scatter(tensors, src=0, group=group.WORLD): + """ + Scatters a list of tensors to all processes in a group. + + Each process will receive exactly one tensor and store its data in the + ``tensor`` argument. + + Arguments: + tensors (list[Tensor]): List of tensors to scatter on the source rank. + Receivers must pass ``None`. + src (int, optional): Source rank (default is 0). + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output tensor from the scatter operation. + + """ + return _Scatter.apply(src, group, *tensors) + + +def reduce(tensor, dst, op=ReduceOp.SUM, group=group.WORLD): + """ + Reduces the tensor data across all machines. + + Only the process with rank ``dst`` is going to receive the final result. + + Arguments: + tensor (Tensor): Input of the collective. + dst (int): Destination rank. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective. + + """ + return _Reduce.apply(dst, op, group, tensor) + + +def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=group.WORLD): + """ + Reduces, then scatters a list of tensors to all processes in a group. + + Arguments: + output (Tensor): Output tensor. + input_list (list[Tensor]): List of tensors to reduce and scatter. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective. + + """ + return _Reduce_Scatter.apply(op, group, output, *input_list) + + +def all_gather(tensor, group=group.WORLD): + """ + Gathers tensors from the whole group in a list. + + Arguments: + tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple([Tensor]): Output of the collective. + + """ + return _AllGather.apply(group, tensor) + + +def _all_gather_base(output_tensor, input_tensor, group=group.WORLD): + """ + Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor. + + Args: + output_tensor (Tensor): Output tensor. It should contain + correctly-sized tensors to be used for output of the collective. + input_tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + Examples: + >>> # All tensors below are of torch.int64 dtype. + >>> # We have 2 process groups, 2 ranks. + >>> # xdoctest: +SKIP("incorrect want text") + >>> output_tensor = torch.zeros(2, dtype=torch.int64) + >>> output_tensor + [tensor([0, 0])] # Rank 0 and 1 + >>> tensor = torch.arange(1, dtype=torch.int64) + 1 + rank + >>> tensor + tensor([1]) # Rank 0 + tensor([2]) # Rank 1 + >>> dist.all_gather_base(output_tensor, tensor) + >>> output_tensor + tensor([1,2]) # Rank 0 + tensor([1,2]) # Rank 1 + + .. warning:: + `_all_gather_base` is experimental and subject to change. + It is the caller's responsibility to ensure the output_tensor + is correctly sized. + + """ + return _AllGatherBase.apply(output_tensor, input_tensor, group) + + +def all_to_all(output_tensor_list, input_tensor_list, group=group.WORLD): + """ + Each process scatters list of input tensors to all processes in a group and + return gathered list of tensors in output list. + + Arguments: + out_tensor_list (list[Tensor]): list of tensors to gather one per rank. + input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple([Tensor]): Output of the collective. + + """ + return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list) + + +def all_to_all_single( + output, + input, + output_split_sizes=None, + input_split_sizes=None, + group=group.WORLD, +): + """ + Each process splits input tensor and then scatters the split list + to all processes in a group. Then concatenate the received tensors from all + the processes in the group and return single output tensor. + + Arguments: + output (Tensor): Gathered cancatenated output tensor. + input (Tensor): Input tensor to scatter. + output_split_sizes: (list[Int], optional): Output split sizes for dim 0 + if specified None or empty, dim 0 of ``output`` tensor must divide + equally by ``world_size``. + input_split_sizes: (list[Int], optional): Input split sizes for dim 0 + if specified None or empty, dim 0 of ``input`` tensor must divide + equally by ``world_size``. + + Returns: + Tensor: Output of the collective. + + """ + return _AlltoAllSingle.apply( + group, output, output_split_sizes, input_split_sizes, input + ) + + +def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD): + """ + Reduces the tensor data across all machines in such a way that all get + the final result. + + After the call the returned tensor is going to be bitwise + identical in all processes. + + Arguments: + tensor (Tensor): Input of the collective. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective + + """ + return _AllReduce.apply(op, group, tensor) + + +class _Broadcast(Function): + @staticmethod + def forward(ctx, src, group, tensor): + ctx.src = src + ctx.group = group + ctx.rank = dist.get_rank() + # torch.distributed makes all the calls in place + # we allocate new tensors to avoid this + tensor = tensor.clone() + dist.broadcast(tensor, src, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output) + if ctx.src != ctx.rank: + gx.zero_() + return (None, None, gx) + + +class _Gather(Function): + @staticmethod + def forward(ctx, dst, group, tensor): + ctx.dst = dst + ctx.group = group + # Need to create a list of tensors here to do the + # aggregation, get it from the group size + # tensor should be correctly sized for the method + # gathering + tensor_list = [ + torch.zeros_like(tensor) for i in range(dist.get_world_size(group=group)) + ] + + tensor = tensor.contiguous() + if dist.get_rank(group=group) == dst: + dist.gather(tensor, tensor_list, dst, group=group) + else: + dist.gather(tensor, None, dst, group=group) + return tuple(tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + return (None, None) + (_Scatter.apply(ctx.dst, ctx.group, *grad_outputs),) + + +class _Scatter(Function): + @staticmethod + def forward(ctx, src, group, *tensors): + ctx.src = src + ctx.group = group + assert all(t.size() == tensors[0].size() for t in tensors) + output = torch.zeros_like(tensors[0]) + if dist.get_rank(group=group) == src: + dist.scatter(output, list(tensors), src, group=group) + else: + dist.scatter(output, None, src, group=group) + return output + + @staticmethod + def backward(ctx, grad_output): + return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output) + + +class _Reduce(Function): + @staticmethod + def forward(ctx, src, op, group, tensor): + ctx.src = src + ctx.group = group + tensor = tensor.clone() + dist.reduce(tensor, src, op=op, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),) + + +class _Reduce_Scatter(Function): + @staticmethod + def forward(ctx, op, group, tensor, *input_tensor_list): + ctx.group = group + input_tensor_list = tuple(t.contiguous() for t in input_tensor_list) + dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + return (None, None, None) + _AllGather.apply(ctx.group, grad_output) + + +class _AllGather(Function): + @staticmethod + def forward(ctx, group, tensor): + ctx.group = group + out_tensor_list = [ + torch.empty_like(tensor) for _ in range(dist.get_world_size(group=group)) + ] + + dist.all_gather(out_tensor_list, tensor.contiguous(), group=group) + return tuple(out_tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + if dist.get_backend(group=ctx.group) is dist.Backend.NCCL: + rank = dist.get_rank() + gx = torch.empty_like(grad_outputs[rank]) + _Reduce_Scatter.apply(ReduceOp.SUM, ctx.group, gx, *grad_outputs) + else: + # As many backends doesn't support ReduceScatter, we use AlltoAll with .sum() + # to emulate the ReduceScatter behavior + tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs] + gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) + gx = torch.sum(torch.stack(gxs), dim=0) + return (None, gx) + + +class _AllGatherBase(Function): + @staticmethod + def forward(ctx, output_tensor, input_tensor, group): + ctx.group = group + dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group) + return output_tensor + + @staticmethod + def backward(ctx, grad_output): + if dist.get_backend(group=ctx.group) is dist.Backend.NCCL: + world_size = dist.get_world_size(group=ctx.group) + out_size = list(grad_output.size()) + if out_size[0] % world_size != 0: + raise RuntimeError( + f"Tensor with dimensions: {out_size} does " + f"not have first dimension divisible by world_size: {world_size}" + ) + out_size[0] = out_size[0] // dist.get_world_size(group=ctx.group) + gx = torch.empty( + out_size, device=grad_output.device, dtype=grad_output.dtype + ) + dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group) + else: + raise RuntimeError("Backend not supported!") + return (None, gx, None) + + +class _AlltoAll(Function): + @staticmethod + def forward(ctx, group, out_tensor_list, *tensors): + ctx.group = group + ctx.input_tensor_size_list = [ + tensors[i].size() for i in range(dist.get_world_size(group=group)) + ] + my_rank = dist.get_rank(group=group) + tensors = tuple(t.contiguous() for t in tensors) + # Implement it on means of scatter/gather, send/recv async operations have issues + if dist.get_backend(group=group) is dist.Backend.GLOO: + for i in range(dist.get_world_size(group=group)): + to_send = None + if i == my_rank: + to_send = list(tensors) + dist.scatter(out_tensor_list[i], to_send, i, group=group) + else: + dist.all_to_all( + out_tensor_list, + list(tensors), + group=group, + ) + return tuple(out_tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + tensor_list = [ + torch.empty( + size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype + ) + for size in ctx.input_tensor_size_list + ] + return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) + + +class _AlltoAllSingle(Function): + @staticmethod + def forward(ctx, group, output, output_split_sizes, input_split_sizes, input): + ctx.group = group + ctx.input_size = input.size() + ctx.output_split_sizes = input_split_sizes + ctx.input_split_sizes = output_split_sizes + dist.all_to_all_single( + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + ) + return output + + @staticmethod + def backward(ctx, grad_output): + tensor = torch.empty( + ctx.input_size, device=grad_output.device, dtype=grad_output.dtype + ) + return (None, None, None, None) + ( + _AlltoAllSingle.apply( + ctx.group, + tensor, + ctx.output_split_sizes, + ctx.input_split_sizes, + grad_output.contiguous(), + ), + ) + + +class _AllReduce(Function): + @staticmethod + def forward(ctx, op, group, tensor): + ctx.group = group + ctx.op = op + tensor = tensor.clone() + dist.all_reduce(tensor, op=op, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),) diff --git a/Adaptive Frequency Filters/utils/visualization_utils.py b/Adaptive Frequency Filters/utils/visualization_utils.py new file mode 100644 index 0000000..bbad2e0 --- /dev/null +++ b/Adaptive Frequency Filters/utils/visualization_utils.py @@ -0,0 +1,134 @@ +# -------------------------------------------------------- +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License +# Written by Zhipeng Huang +# -------------------------------------------------------- +import random + +from torch import Tensor +import cv2 +import numpy as np +import copy +from typing import Optional, List, Tuple +from matplotlib.colors import hsv_to_rgb + +from utils.color_map import Colormap +from utils import logger + +FONT_SIZE = cv2.FONT_HERSHEY_PLAIN +LABEL_COLOR = [255, 255, 255] +TEXT_THICKNESS = 1 +RECT_BORDER_THICKNESS = 2 + + +def visualize_boxes_xyxy(image: np.ndarray, boxes: np.ndarray) -> np.ndarray: + """Utility function to draw bounding boxes of objects on a given image""" + boxes = boxes.astype(np.int) + + new_image = copy.deepcopy(image) + for box_idx in range(boxes.shape[0]): + coords = boxes[box_idx] + r, g, b = 255, 0, 0 + # top -left corner + start_coord = (coords[0], coords[1]) + # bottom-right corner + end_coord = (coords[2], coords[3]) + cv2.rectangle(new_image, end_coord, start_coord, (r, g, b), thickness=1) + return new_image + + +def create_colored_mask(mask: np.ndarray, num_classes: int, *args, **kwargs) -> np.ndarray: + """Create a colored mask with random colors""" + colored_mask = np.ones((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) + # 0 for background. + random_hue = random.randint(1, num_classes) + + random_mask_color = hsv_to_rgb((random_hue / num_classes, 0.75, 0.75)) + colored_mask[..., :] = [int(c * 255.) for c in random_mask_color] + colored_mask *= mask[..., None] + return colored_mask + + +def draw_bounding_boxes( + image: np.ndarray, + boxes: np.ndarray, + labels: np.ndarray, + scores: np.ndarray, + masks: Optional[np.ndarray] = None, + color_map: Optional = None, + object_names: Optional[List] = None, + is_bgr_format: Optional[bool] = False, + save_path: Optional[str] = None, + num_classes: Optional[int] = 81 +) -> None: + """Utility function to draw bounding boxes of objects along with their labels and score on a given image""" + boxes = boxes.astype(np.int) + + if is_bgr_format: + # convert from BGR to RGB colorspace + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + if color_map is None: + color_map = Colormap().get_box_color_codes() + + if masks is None: + masks = [None] * len(boxes) + + for label, score, coords, mask in zip(labels, scores, boxes, masks): + r, g, b = color_map[label] + c1 = (coords[0], coords[1]) + c2 = (coords[2], coords[3]) + + if mask is not None: + mask = create_colored_mask(mask=mask, num_classes=num_classes) + image = cv2.addWeighted(image, 1.0, mask, 1.0, gamma=0.0) + + cv2.rectangle(image, c1, c2, (r, g, b), thickness=RECT_BORDER_THICKNESS) + if object_names is not None: + label_text = "{label}: {score:.2f}".format( + label=object_names[label], score=score + ) + t_size = cv2.getTextSize(label_text, FONT_SIZE, 1, TEXT_THICKNESS)[0] + new_c2 = c1[0] + t_size[0] + 3, c1[1] + t_size[1] + 4 + + cv2.rectangle(image, c1, new_c2, (r, g, b), -1) + cv2.putText( + image, + label_text, + (c1[0], c1[1] + t_size[1] + 4), + FONT_SIZE, + 1, + LABEL_COLOR, + TEXT_THICKNESS, + ) + + if save_path is not None: + cv2.imwrite(save_path, image) + logger.log("Detection results stored at: {}".format(save_path)) + return image + + +def convert_to_cityscape_format(img: Tensor) -> Tensor: + """Utility to map predicted segmentation labels to cityscapes format""" + img[img == 19] = 255 + img[img == 18] = 33 + img[img == 17] = 32 + img[img == 16] = 31 + img[img == 15] = 28 + img[img == 14] = 27 + img[img == 13] = 26 + img[img == 12] = 25 + img[img == 11] = 24 + img[img == 10] = 23 + img[img == 9] = 22 + img[img == 8] = 21 + img[img == 7] = 20 + img[img == 6] = 19 + img[img == 5] = 17 + img[img == 4] = 13 + img[img == 3] = 12 + img[img == 2] = 11 + img[img == 1] = 8 + img[img == 0] = 7 + img[img == 255] = 0 + return img