-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
241 changed files
with
40,286 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# -------------------------------------------------------- | ||
# Copyright (c) 2023 Microsoft | ||
# Licensed under The MIT License | ||
# Written by Zhipeng Huang | ||
# -------------------------------------------------------- | ||
|
||
import argparse | ||
|
||
|
||
from options.utils import extend_selected_args_with_prefix | ||
from affnet.misc.common import parameter_list | ||
from affnet.anchor_generator import arguments_anchor_gen | ||
from affnet.image_projection_layers import arguments_image_projection_head | ||
from affnet.layers import arguments_nn_layers | ||
from affnet.matcher_det import arguments_box_matcher | ||
from affnet.misc.averaging_utils import arguments_ema, EMA | ||
from affnet.misc.profiler import module_profile | ||
from affnet.models import arguments_model, get_model | ||
from affnet.models.detection.base_detection import DetectionPredTuple | ||
from affnet.neural_augmentor import arguments_neural_augmentor | ||
from affnet.text_encoders import arguments_text_encoder | ||
|
||
|
||
def modeling_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: | ||
# model arguments | ||
parser = arguments_model(parser) | ||
# neural network layer argumetns | ||
parser = arguments_nn_layers(parser) | ||
# EMA arguments | ||
parser = arguments_ema(parser) | ||
# anchor generator arguments (for object detection) | ||
parser = arguments_anchor_gen(parser) | ||
# box matcher arguments (for object detection) | ||
parser = arguments_box_matcher(parser) | ||
# text encoder arguments (usually for multi-modal tasks) | ||
parser = arguments_text_encoder(parser) | ||
# image projection head arguments (usually for multi-modal tasks) | ||
parser = arguments_image_projection_head(parser) | ||
# neural aug arguments | ||
parser = arguments_neural_augmentor(parser) | ||
|
||
# Add teacher as a prefix to enable distillation tasks | ||
# keep it as the last entry | ||
parser = extend_selected_args_with_prefix( | ||
parser, check_string="--model", add_prefix="--teacher." | ||
) | ||
|
||
return parser |
86 changes: 86 additions & 0 deletions
86
Adaptive Frequency Filters/affnet/anchor_generator/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# -------------------------------------------------------- | ||
# Copyright (c) 2023 Microsoft | ||
# Licensed under The MIT License | ||
# Written by Zhipeng Huang | ||
# -------------------------------------------------------- | ||
|
||
import argparse | ||
import os | ||
import importlib | ||
|
||
from utils import logger | ||
from utils.ddp_utils import is_master | ||
|
||
from .base_anchor_generator import BaseAnchorGenerator | ||
|
||
# register anchor generator | ||
ANCHOR_GEN_REGISTRY = {} | ||
|
||
|
||
def register_anchor_generator(name): | ||
"""Register anchor generators for object detection""" | ||
|
||
def register_class(cls): | ||
if name in ANCHOR_GEN_REGISTRY: | ||
raise ValueError( | ||
"Cannot register duplicate anchor generator ({})".format(name) | ||
) | ||
|
||
if not issubclass(cls, BaseAnchorGenerator): | ||
raise ValueError( | ||
"Anchor generator ({}: {}) must extend BaseAnchorGenerator".format( | ||
name, cls.__name__ | ||
) | ||
) | ||
|
||
ANCHOR_GEN_REGISTRY[name] = cls | ||
return cls | ||
|
||
return register_class | ||
|
||
|
||
def arguments_anchor_gen(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: | ||
"""Arguments related to anchor generator for object detection""" | ||
group = parser.add_argument_group("Anchor generator", "Anchor generator") | ||
group.add_argument( | ||
"--anchor-generator.name", type=str, help="Name of the anchor generator" | ||
) | ||
|
||
for k, v in ANCHOR_GEN_REGISTRY.items(): | ||
parser = v.add_arguments(parser=parser) | ||
|
||
return parser | ||
|
||
|
||
def build_anchor_generator(opts, *args, **kwargs): | ||
"""Build anchor generator for object detection""" | ||
anchor_gen_name = getattr(opts, "anchor_generator.name", None) | ||
anchor_gen = None | ||
if anchor_gen_name in ANCHOR_GEN_REGISTRY: | ||
anchor_gen = ANCHOR_GEN_REGISTRY[anchor_gen_name](opts, *args, **kwargs) | ||
else: | ||
supported_anchor_gens = list(ANCHOR_GEN_REGISTRY.keys()) | ||
supp_anchor_gen_str = ( | ||
"Got {} as anchor generator. Supported anchor generators are:".format( | ||
anchor_gen_name | ||
) | ||
) | ||
for i, m_name in enumerate(supported_anchor_gens): | ||
supp_anchor_gen_str += "\n\t {}: {}".format(i, logger.color_text(m_name)) | ||
|
||
if is_master(opts): | ||
logger.error(supp_anchor_gen_str) | ||
return anchor_gen | ||
|
||
|
||
# automatically import the anchor generators | ||
anchor_gen_dir = os.path.dirname(__file__) | ||
for file in os.listdir(anchor_gen_dir): | ||
path = os.path.join(anchor_gen_dir, file) | ||
if ( | ||
not file.startswith("_") | ||
and not file.startswith(".") | ||
and (file.endswith(".py") or os.path.isdir(path)) | ||
): | ||
anc_gen = file[: file.find(".py")] if file.endswith(".py") else file | ||
module = importlib.import_module("affnet.anchor_generator." + anc_gen) |
99 changes: 99 additions & 0 deletions
99
Adaptive Frequency Filters/affnet/anchor_generator/base_anchor_generator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# -------------------------------------------------------- | ||
# Copyright (c) 2023 Microsoft | ||
# Licensed under The MIT License | ||
# Written by Zhipeng Huang | ||
# -------------------------------------------------------- | ||
|
||
import torch | ||
from torch import Tensor | ||
import argparse | ||
from typing import Optional, Tuple, Union | ||
|
||
|
||
class BaseAnchorGenerator(torch.nn.Module): | ||
""" | ||
Base class for anchor generators for the task of object detection. | ||
""" | ||
|
||
def __init__(self, *args, **kwargs) -> None: | ||
super().__init__() | ||
self.anchors_dict = dict() | ||
|
||
@classmethod | ||
def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: | ||
""" | ||
Add anchor generator-specific arguments to the parser | ||
""" | ||
return parser | ||
|
||
def num_anchors_per_os(self): | ||
"""Returns anchors per output stride. Child classes must implement this function.""" | ||
raise NotImplementedError | ||
|
||
@torch.no_grad() | ||
def _generate_anchors( | ||
self, | ||
height: int, | ||
width: int, | ||
output_stride: int, | ||
device: Optional[str] = "cpu", | ||
*args, | ||
**kwargs | ||
) -> Union[Tensor, Tuple[Tensor, ...]]: | ||
raise NotImplementedError | ||
|
||
@torch.no_grad() | ||
def _get_anchors( | ||
self, | ||
fm_height: int, | ||
fm_width: int, | ||
fm_output_stride: int, | ||
device: Optional[str] = "cpu", | ||
*args, | ||
**kwargs | ||
) -> Union[Tensor, Tuple[Tensor, ...]]: | ||
key = "h_{}_w_{}_os_{}".format(fm_height, fm_width, fm_output_stride) | ||
if key not in self.anchors_dict: | ||
default_anchors_ctr = self._generate_anchors( | ||
height=fm_height, | ||
width=fm_width, | ||
output_stride=fm_output_stride, | ||
device=device, | ||
*args, | ||
**kwargs | ||
) | ||
self.anchors_dict[key] = default_anchors_ctr | ||
return default_anchors_ctr | ||
else: | ||
return self.anchors_dict[key] | ||
|
||
@torch.no_grad() | ||
def forward( | ||
self, | ||
fm_height: int, | ||
fm_width: int, | ||
fm_output_stride: int, | ||
device: Optional[str] = "cpu", | ||
*args, | ||
**kwargs | ||
) -> Union[Tensor, Tuple[Tensor, ...]]: | ||
""" | ||
Returns anchors for the feature map | ||
Args: | ||
fm_height (int): Height of the feature map | ||
fm_width (int): Width of the feature map | ||
fm_output_stride (int): Output stride of the feature map | ||
device (Optional, str): Device (cpu or cuda). Defaults to cpu | ||
Returns: | ||
Tensor or Tuple of Tensors | ||
""" | ||
return self._get_anchors( | ||
fm_height=fm_height, | ||
fm_width=fm_width, | ||
fm_output_stride=fm_output_stride, | ||
device=device, | ||
*args, | ||
**kwargs | ||
) |
Oops, something went wrong.