Skip to content

Commit

Permalink
CBNet-EVA
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiwei committed Oct 18, 2024
1 parent 7233410 commit 9d1e164
Show file tree
Hide file tree
Showing 125 changed files with 10,815 additions and 121 deletions.
95 changes: 95 additions & 0 deletions EVA/EVA-02/det/configs/common/models/cb_mask_rcnn_fpn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from detectron2.config import LazyCall as L
from detectron2.layers import ShapeSpec
from detectron2.modeling.meta_arch import CBGeneralizedRCNN
from detectron2.modeling.anchor_generator import DefaultAnchorGenerator
from detectron2.modeling.backbone.fpn import LastLevelMaxPool
from detectron2.modeling.backbone import BasicStem, FPN, ResNet
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.matcher import Matcher
from detectron2.modeling.poolers import ROIPooler
from detectron2.modeling.proposal_generator import RPN, StandardRPNHead
from detectron2.modeling.roi_heads import (
StandardROIHeads,
FastRCNNOutputLayers,
MaskRCNNConvUpsampleHead,
FastRCNNConvFCHead,
)

from ..data.constants import constants

# Mask R-CNN style detector expressed as a LazyCall (`L`) config tree with
# CBGeneralizedRCNN as the top-level target. Strings of the form "${...}" are
# config interpolations that refer to other nodes of this tree.
model = L(CBGeneralizedRCNN)(
    # Backbone: FPN on top of a depth-50 ResNet whose stem and stages use "FrozenBN".
    backbone=L(FPN)(
        bottom_up=L(ResNet)(
            stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
            stages=L(ResNet.make_default_stages)(
                depth=50,
                stride_in_1x1=True,
                norm="FrozenBN",
            ),
            out_features=["res2", "res3", "res4", "res5"],
        ),
        # Relative reference to `bottom_up.out_features` declared just above.
        in_features="${.bottom_up.out_features}",
        out_channels=256,
        top_block=L(LastLevelMaxPool)(),
    ),
    # Region proposal network over five feature levels (p2-p6).
    proposal_generator=L(RPN)(
        in_features=["p2", "p3", "p4", "p5", "p6"],
        # num_anchors=3 matches the three aspect ratios below combined with a
        # single size per `sizes` entry.
        head=L(StandardRPNHead)(in_channels=256, num_anchors=3),
        anchor_generator=L(DefaultAnchorGenerator)(
            # Five size/stride entries, one per entry of `in_features` above.
            sizes=[[32], [64], [128], [256], [512]],
            aspect_ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64],
            offset=0.0,
        ),
        anchor_matcher=L(Matcher)(
            thresholds=[0.3, 0.7], labels=[0, -1, 1], allow_low_quality_matches=True
        ),
        box2box_transform=L(Box2BoxTransform)(weights=[1.0, 1.0, 1.0, 1.0]),
        batch_size_per_image=256,
        positive_fraction=0.5,
        # NOTE(review): the two-element tuples are presumably (training, testing)
        # limits -- confirm against the RPN signature.
        pre_nms_topk=(2000, 1000),
        post_nms_topk=(1000, 1000),
        nms_thresh=0.7,
    ),
    # Box and mask heads operating on p2-p5 (p6 is only used by the RPN above).
    roi_heads=L(StandardROIHeads)(
        num_classes=80,
        batch_size_per_image=512,
        positive_fraction=0.25,
        proposal_matcher=L(Matcher)(
            thresholds=[0.5], labels=[0, 1], allow_low_quality_matches=False
        ),
        box_in_features=["p2", "p3", "p4", "p5"],
        box_pooler=L(ROIPooler)(
            output_size=7,
            # Reciprocals of the first four anchor strides (4, 8, 16, 32).
            scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32),
            sampling_ratio=0,
            pooler_type="ROIAlignV2",
        ),
        # Box head without conv layers: two 1024-d FC layers on the 7x7 pooled features.
        box_head=L(FastRCNNConvFCHead)(
            input_shape=ShapeSpec(channels=256, height=7, width=7),
            conv_dims=[],
            fc_dims=[1024, 1024],
        ),
        box_predictor=L(FastRCNNOutputLayers)(
            # 1024 channels: the last entry of the box head's fc_dims.
            input_shape=ShapeSpec(channels=1024),
            test_score_thresh=0.05,
            box2box_transform=L(Box2BoxTransform)(weights=(10, 10, 5, 5)),
            # Relative reference to `num_classes` of the enclosing roi_heads node.
            num_classes="${..num_classes}",
        ),
        mask_in_features=["p2", "p3", "p4", "p5"],
        mask_pooler=L(ROIPooler)(
            output_size=14,
            scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32),
            sampling_ratio=0,
            pooler_type="ROIAlignV2",
        ),
        mask_head=L(MaskRCNNConvUpsampleHead)(
            input_shape=ShapeSpec(channels=256, width=14, height=14),
            num_classes="${..num_classes}",
            conv_dims=[256, 256, 256, 256, 256],
        ),
    ),
    # Input normalization constants for BGR-ordered images.
    pixel_mean=constants.imagenet_bgr256_mean,
    pixel_std=constants.imagenet_bgr256_std,
    input_format="BGR",
)
87 changes: 87 additions & 0 deletions EVA/EVA-02/det/configs/common/models/cb_mask_rcnn_vitdet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from functools import partial
import torch.nn as nn
from detectron2.config import LazyCall as L
from detectron2.modeling import CBViT, CBSimpleFeaturePyramid
from detectron2.modeling.backbone.fpn import LastLevelMaxPool

from .cb_mask_rcnn_fpn import model
from ..data.constants import constants

# The imported base config (cb_mask_rcnn_fpn) is set up for BGR input with the
# BGR normalization constants; switch all three settings to their RGB counterparts.
model.pixel_mean = constants.imagenet_rgb256_mean
model.pixel_std = constants.imagenet_rgb256_std
model.input_format = "RGB"

# from apex.normalization import FusedLayerNorm

# Base-size backbone settings shared by both ViT configs built below.
embed_dim = 768
depth = 12
num_heads = 12
dp = 0.1  # passed as drop_path_rate
# Replaces the backbone inherited from cb_mask_rcnn_fpn with a
# CBSimpleFeaturePyramid built from two ViT backbones (`net` and `cb_net`).
# Both CBViT nodes below are configured with identical arguments.
model.backbone = L(CBSimpleFeaturePyramid)(
    net=L(CBViT)(  # Single-scale ViT backbone
        img_size=1024,
        patch_size=16,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        drop_path_rate=dp,
        window_size=14,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        window_block_indexes=[
            # Blocks 2, 5, 8 and 11 are left out of this list: they use global attention.
            0,
            1,
            3,
            4,
            6,
            7,
            9,
            10,
        ],
        residual_block_indexes=[],
        use_rel_pos=True,
        out_feature="last_feat",
    ),
    cb_net=L(CBViT)(  # Single-scale ViT backbone
        img_size=1024,
        patch_size=16,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        drop_path_rate=dp,
        window_size=14,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        window_block_indexes=[
            # Blocks 2, 5, 8 and 11 are left out of this list: they use global attention.
            0,
            1,
            3,
            4,
            6,
            7,
            9,
            10,
        ],
        residual_block_indexes=[],
        use_rel_pos=True,
        out_feature="last_feat",
    ),
    # Relative reference to `net.out_feature` ("last_feat") declared above.
    in_feature="${.net.out_feature}",
    out_channels=256,
    # Four pyramid scales relative to the single ViT feature map.
    scale_factors=(4.0, 2.0, 1.0, 0.5),
    top_block=L(LastLevelMaxPool)(),
    norm="LN",
    # Same value as img_size of the ViTs above.
    square_pad=1024,
)

# Conv layers of the box head and the mask head both use the "LN" norm setting.
model.roi_heads.box_head.conv_norm = "LN"
model.roi_heads.mask_head.conv_norm = "LN"

# RPN head with two conv layers.
# NOTE(review): -1 presumably means "keep the input channel count" -- confirm
# against StandardRPNHead.
model.proposal_generator.head.conv_dims = [-1, -1]

# Box head layout "4conv1fc": four 256-channel convs followed by one 1024-d FC,
# replacing the two-FC layout of the base config.
model.roi_heads.box_head.conv_dims = [256, 256, 256, 256]
model.roi_heads.box_head.fc_dims = [1024]
21 changes: 21 additions & 0 deletions EVA/EVA-02/det/detectron2/modeling/backbone/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from .build import build_backbone, BACKBONE_REGISTRY # noqa F401 isort:skip

from .backbone import Backbone
from .fpn import FPN
from .regnet import RegNet
from .resnet import (
BasicStem,
ResNet,
ResNetBlockBase,
build_resnet_backbone,
make_stage,
BottleneckBlock,
)
from .vit import ViT, SimpleFeaturePyramid, get_vit_lr_decay_rate
from .mvit import MViT
from .swin import SwinTransformer
from .cb_vit import CBViT, CBSimpleFeaturePyramid

# Export every module-level name that does not begin with an underscore.
__all__ = [name for name in globals() if not name.startswith("_")]
# TODO can expose more resnet blocks after careful consideration
74 changes: 74 additions & 0 deletions EVA/EVA-02/det/detectron2/modeling/backbone/backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from abc import ABCMeta, abstractmethod
from typing import Dict
import torch.nn as nn

from detectron2.layers import ShapeSpec

__all__ = ["Backbone"]


class Backbone(nn.Module, metaclass=ABCMeta):
    """
    Abstract base class for network backbones.

    Subclasses must implement :meth:`forward`. The default :meth:`output_shape`
    additionally reads ``self._out_features``, ``self._out_feature_channels`` and
    ``self._out_feature_strides``; none of these are set by this base class, so a
    subclass must either define them or override :meth:`output_shape`.
    """

    def __init__(self):
        """
        The `__init__` method of any subclass can specify its own set of arguments.
        """
        super().__init__()

    @abstractmethod
    def forward(self):
        """
        Subclasses must override this method, but adhere to the same return type.

        Returns:
            dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
        """
        pass

    @property
    def size_divisibility(self) -> int:
        """
        Some backbones require the input height and width to be divisible by a
        specific integer. This is typically true for encoder / decoder type networks
        with lateral connection (e.g., FPN) for which feature maps need to match
        dimension in the "bottom up" and "top down" paths. Set to 0 if no specific
        input size divisibility is required.
        """
        return 0

    @property
    def padding_constraints(self) -> Dict[str, int]:
        """
        This property is a generalization of size_divisibility. Some backbones and training
        recipes require specific padding constraints, such as enforcing divisibility by a specific
        integer (e.g., FPN) or padding to a square (e.g., ViTDet with large-scale jitter
        in :paper:`vitdet`). `padding_constraints` contains these optional items like:

            {
                "size_divisibility": int,
                "square_size": int,
                # Future options are possible
            }

        `size_divisibility` will read from here if present and `square_size` indicates the
        square padding size if `square_size` > 0.

        TODO: use type of Dict[str, int] to avoid torchscript issues. The type of
        padding_constraints could be generalized as TypedDict (Python 3.8+) to support
        more types in the future.
        """
        return {}

    def output_shape(self):
        """
        Describe each output feature by its channel count and stride.

        Returns:
            dict[str->ShapeSpec]: one entry per name in ``self._out_features``.
        """
        # this is a backward-compatible default
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }
33 changes: 33 additions & 0 deletions EVA/EVA-02/det/detectron2/modeling/backbone/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from detectron2.layers import ShapeSpec
from detectron2.utils.registry import Registry

from .backbone import Backbone

# Registry of backbone builders; `build_backbone` in this module looks an entry up
# by `cfg.MODEL.BACKBONE.NAME` and calls it with `(cfg, input_shape)`.
BACKBONE_REGISTRY = Registry("BACKBONE")
BACKBONE_REGISTRY.__doc__ = """
Registry for backbones, which extract feature maps from images
The registered object must be a callable that accepts two arguments:
1. A :class:`detectron2.config.CfgNode`
2. A :class:`detectron2.layers.ShapeSpec`, which contains the input shape specification.
Registered object must return instance of :class:`Backbone`.
"""


def build_backbone(cfg, input_shape=None):
    """
    Build a backbone from `cfg.MODEL.BACKBONE.NAME`.

    Args:
        cfg: config object; `cfg.MODEL.BACKBONE.NAME` selects the registry entry and
            `cfg.MODEL.PIXEL_MEAN` is read only when `input_shape` is None.
        input_shape (ShapeSpec, optional): input shape specification passed to the
            registered builder. When None, a ShapeSpec with
            `channels=len(cfg.MODEL.PIXEL_MEAN)` is used.

    Returns:
        an instance of :class:`Backbone`

    Raises:
        AssertionError: if the registered builder returns an object that is not a
            :class:`Backbone`.
    """
    if input_shape is None:
        input_shape = ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN))

    backbone_name = cfg.MODEL.BACKBONE.NAME
    backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg, input_shape)
    # Name the offending registry entry so a misbehaving builder is easy to locate.
    assert isinstance(backbone, Backbone), (
        "Backbone builder '{}' returned {}, expected an instance of Backbone".format(
            backbone_name, type(backbone).__name__
        )
    )
    return backbone
Loading

0 comments on commit 9d1e164

Please sign in to comment.