diff --git a/.gitignore b/.gitignore index 1d0337c..d54c776 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ __pycache__ .ipynb_checkpoints cam_images/ .vscode +ut.py diff --git a/configs/api-net/api-net-resnet101.yml b/configs/api_net/api_net_resnet101.yml similarity index 95% rename from configs/api-net/api-net-resnet101.yml rename to configs/api_net/api_net_resnet101.yml index c41eb20..c0320b4 100644 --- a/configs/api-net/api-net-resnet101.yml +++ b/configs/api_net/api_net_resnet101.yml @@ -114,8 +114,11 @@ TRANSFORMS: OPTIMIZER: NAME: "SGD" - MOMENTUM: 0.9 + ARGS: + - momentum: 0.9 + - weight_decay: 0.0005 LR: + base: 0.002 backbone: 0.002 encoder: 0.01 necks: 0.01 @@ -124,7 +127,7 @@ OPTIMIZER: ITERATION_NUM: ~ EPOCH_NUM: 100 START_EPOCH: 0 -UPDATE_STRATEGY: "general_updating" +UPDATE_STRATEGY: "general_strategy" # Validation details diff --git a/configs/cal/cal_resnet101.yml b/configs/cal/cal_resnet101.yml index 05fce78..eaa8824 100644 --- a/configs/cal/cal_resnet101.yml +++ b/configs/cal/cal_resnet101.yml @@ -105,9 +105,11 @@ TRANSFORMS: OPTIMIZER: NAME: SGD - MOMENTUM: 0.9 - WEIGHT_DECAY: 1e-5 + ARGS: + - momentum: 0.9 + - weight_decay: 0.0005 LR: + base: 0.001 backbone: 0.001 encoder: 0.001 necks: 0.001 @@ -116,14 +118,14 @@ OPTIMIZER: ITERATION_NUM: ~ EPOCH_NUM: 160 START_EPOCH: 0 -UPDATE_STRATEGY: "general_updating" +UPDATE_STRATEGY: "general_strategy" LR_SCHEDULE: NAME: "adjusting_schedule" ARGS: - base_lr: 0.001 - base_duration: 2.0 - base_rate: 0.9 - - update_level: "batch_update" + # - update_level: "batch_update" # Validation details PER_ITERATION: ~ diff --git a/configs/mutual_channel_loss/mcl_vgg16.yml b/configs/mutual_channel_loss/mcl_vgg16.yml index 675c961..9f5492c 100644 --- a/configs/mutual_channel_loss/mcl_vgg16.yml +++ b/configs/mutual_channel_loss/mcl_vgg16.yml @@ -102,8 +102,11 @@ TRANSFORMS: OPTIMIZER: NAME: SGD - MOMENTUM: 0.9 + ARGS: + - momentum: 0.9 + - weight_decay: 0.0005 LR: + base: 0.001 backbone: 0.001 encoder: 0.01 necks: 0.01 @@ -112,7 +115,7 @@ OPTIMIZER: ITERATION_NUM: ~ EPOCH_NUM: 200 START_EPOCH: 0 -UPDATE_STRATEGY: "general_updating" +UPDATE_STRATEGY: "general_strategy" # Validation details PER_ITERATION: ~ @@ -134,9 +137,4 @@ METRICS: metric: "precision" top_k: ~ threshold: 0.5 - -INTERPRETER: - NAME: "cam" - METHOD: "gradcam" - TARGET_LAYERS: - - "layer4" \ No newline at end of file + \ No newline at end of file diff --git a/configs/progressive_multi_granularity_learning/pmg_resnet50.yml b/configs/progressive_multi_granularity_learning/pmg_resnet50.yml index f49bada..f0ae46a 100644 --- a/configs/progressive_multi_granularity_learning/pmg_resnet50.yml +++ b/configs/progressive_multi_granularity_learning/pmg_resnet50.yml @@ -106,8 +106,11 @@ TRANSFORMS: OPTIMIZER: NAME: "SGD" - MOMENTUM: 0.9 + ARGS: + - momentum: 0.9 + - weight_decay: 0.0005 LR: + base: 0.0002 backbone: 0.0002 encoder: 0.002 necks: 0.002 @@ -140,8 +143,3 @@ METRICS: top_k: ~ threshold: 0.5 -INTERPRETER: - NAME: "cam" - METHOD: "gradcam" - TARGET_LAYERS: - - "layer4" diff --git a/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml b/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml index 1b76992..c9b5635 100644 --- a/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml +++ b/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml @@ -28,6 +28,16 @@ DATASET: MODEL: NAME: "PMG_V2" CLASS_NUM: 200 + ARGS: + - outputs_num: 3 + - BLOCKS: + - [8, 8, 0, 0] + - [4, 4, 4, 0] + - [2, 2, 2, 2] + - alpha: + - 0.01 + - 0.05 + - 
0.1 CRITERIONS: - name: "cross_entropy_loss" args: [] @@ -108,8 +118,11 @@ TRANSFORMS: OPTIMIZER: NAME: "SGD" - MOMENTUM: 0.9 + ARGS: + - momentum: 0.9 + - weight_decay: 0.0005 LR: + base: 0.0005 backbone: 0.0005 encoder: ~ necks: 0.005 @@ -140,9 +153,3 @@ METRICS: metric: "precision" top_k: ~ threshold: 0.5 - -INTERPRETER: - NAME: "cam" - METHOD: "gradcam" - TARGET_LAYERS: - - "layer4" \ No newline at end of file diff --git a/configs/resnet/resnet50.yml b/configs/resnet/resnet50.yml index 7c52d2a..717f113 100644 --- a/configs/resnet/resnet50.yml +++ b/configs/resnet/resnet50.yml @@ -88,17 +88,20 @@ TRANSFORMS: OPTIMIZER: NAME: "SGD" - MOMENTUM: 0.9 - LR: + ARGS: + - momentum: 0.9 + - weight_decay: 0.0005 + LR: + base: 0.0002 backbone: 0.0002 encoder: 0.002 necks: 0.002 heads: 0.002 ITERATION_NUM: ~ -EPOCH_NUM: 1 +EPOCH_NUM: 50 START_EPOCH: 0 -UPDATE_STRATEGY: "general_updating" +UPDATE_STRATEGY: "general_strategy" # Validation details PER_ITERATION: ~ diff --git a/configs/resnet/resnet50_cutmix.yml b/configs/resnet/resnet50_cutmix.yml index 86f517f..0930d5d 100644 --- a/configs/resnet/resnet50_cutmix.yml +++ b/configs/resnet/resnet50_cutmix.yml @@ -88,8 +88,11 @@ TRANSFORMS: OPTIMIZER: NAME: "SGD" - MOMENTUM: 0.9 - LR: + ARGS: + - momentum: 0.9 + - weight_decay: 0.0005 + LR: + base: 0.0002 backbone: 0.0002 encoder: 0.002 necks: 0.002 diff --git a/configs/swin_transformer/swin_transformer.yml b/configs/swin_transformer/swin_transformer.yml new file mode 100644 index 0000000..bd228cc --- /dev/null +++ b/configs/swin_transformer/swin_transformer.yml @@ -0,0 +1,168 @@ +EXP_NAME: "SwinT" + +RESUME_WEIGHT: ~ + +DISTRIBUTED: False + +WEIGHT: + NAME: "swinT.pth" + SAVE_DIR: "/mnt/sdb/data/wangxinran/weight/fgvclib" + +LOGGER: + NAME: "txt_logger" + +DATASET: + NAME: "CUB_200_2011" + ROOT: "/mnt/sdb/data/wangxinran/dataset" + TRAIN: + BATCH_SIZE: 16 + POSITIVE: 0 + PIN_MEMORY: True + SHUFFLE: True + NUM_WORKERS: 4 + TEST: + BATCH_SIZE: 16 + POSITIVE: 0 + PIN_MEMORY: False + SHUFFLE: True + NUM_WORKERS: 4 + +MODEL: + NAME: "SwinTransformer" + CLASS_NUM: 200 + ARGS: + - img_size: 384 + - fpn_size: 1536 + - lambda_s: 0.0 + - lambda_n: 5.0 + - lambda_b: 0.5 + - lambda_c: 1.0 + - update_freq: 2 + - use_selection: True + - use_fpn: True + - use_combiner: True + - num_select: + - layer1: 2048 + - layer2: 512 + - layer3: 128 + - layer4: 32 + + CRITERIONS: + - name: "cross_entropy_loss" + args: [] + w: 1.0 + - name: "mean_square_error_loss" + args: [] + w: 1.0 + BACKBONE: + NAME: "swin_large_patch4_window12_384_in22k" + ARGS: + - pretrained: True + ENCODER: + NAME: ~ + NECKS: + NAME: ~ + HEADS: + NAME: "GCN_combiner" + ARGS: + - num_selects: + - layer1: 2048 + - layer2: 512 + - layer3: 128 + - layer4: 32 + - total_num_selects: 2720 + - num_classes: 200 + - fpn_size: 1536 + +TRANSFORMS: + TRAIN: + - name: "resize" + size: + - 510 + - 510 + - name: "random_crop" + size: 384 + padding: 0 + - name: "random_horizontal_flip" + prob: 0.5 + - name: "randomApply_gaussianBlur" + prob: 0.1 + - name: "randomAdjust_sharpness" + sharpness_factor: 1.5 + prob: 0.1 + - name: "to_tensor" + - name: "normalize" + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + TEST: + - name: "resize" + size: + - 510 + - 510 + - name: "center_crop" + size: 384 + - name: "to_tensor" + - name: "normalize" + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + +OPTIMIZER: + NAME: "AdamW" + ARGS: + - weight_decay: 0.0005 + LR: + base: 0.0001 + backbone: 0.0001 + encoder: 0.0001 + necks: 0.0001 + 
heads: 0.0001 + +LR_SCHEDULE: + NAME: "warmup_cosine_decay_schedule" + ARGS: + - warmup_steps: 800 + - max_lr: 0.0001 + - max_epochs: 50 + - decay_type: 1 + +ITERATION_NUM: ~ +EPOCH_NUM: 50 +START_EPOCH: 0 +AMP: True +UPDATE_STRATEGY: "" +UPDATE_FUNCTION: "update_swin_transformer" +EVALUATE_FUNCTION: "swin_transformer_evaluate" + +# Validation details +PER_ITERATION: ~ +PER_EPOCH: ~ +METRICS: + - name: "accuracy(topk=1)" + metric: "accuracy" + top_k: 1 + threshold: ~ + - name: "accuracy(topk=5)" + metric: "accuracy" + top_k: 5 + threshold: ~ + - name: "recall(threshold=0.5)" + metric: "recall" + top_k: ~ + threshold: 0.5 + - name: "precision(threshold=0.5)" + metric: "precision" + top_k: ~ + threshold: 0.5 + \ No newline at end of file diff --git a/configs/transfg/transFG_ViT_B_16.yml b/configs/transfg/transFG_ViT_B_16.yml new file mode 100644 index 0000000..a59551c --- /dev/null +++ b/configs/transfg/transFG_ViT_B_16.yml @@ -0,0 +1,159 @@ +EXP_NAME: "TransFG" + +RESUME_WEIGHT: ~ + +WEIGHT: + NAME: "transFG_ViT_B16.pth" + SAVE_DIR: "/mnt/sdb/data/wangxinran/weight/fgvclib" + +LOGGER: + NAME: "txt_logger" + +DATASET: + ROOT: "/mnt/sdb/data/wangxinran/dataset" + NAME: "CUB_200_2011" + TRAIN: + BATCH_SIZE: 8 + POSITIVE: 0 + PIN_MEMORY: True + SHUFFLE: True + NUM_WORKERS: 4 + TEST: + BATCH_SIZE: 8 + POSITIVE: 0 + PIN_MEMORY: False + SHUFFLE: False + NUM_WORKERS: 4 + +MODEL: + NAME: "TransFG" + CLASS_NUM: 200 + ARGS: + - smoothing_value: 0 + - zero_head: True + - classifier: 'token' + - part_head_in_channels: 768 + - pretrained_weight: "/mnt/sdb/data/wangxinran/weight/pretraining/transFG/ViT-B_16.npz" + + CRITERIONS: + - name: "nll_loss_labelsmoothing" + args: + - smoothing_value: 0.0 + w: 1.0 + - name: "cross_entropy_loss" + args: [] + w: 1.0 + BACKBONE: + NAME: "vit16" + ARGS: + - patch_size: 16 + - image_size: 448 + - split: 'non-overlap' + - slide_step: 12 + - hidden_size: 768 + - representation_size: None + - dropout_rate: 0.1 + ENCODER: + NAME: "transformer_encoder" + ARGS: + - num_layers: 12 + - img_size: 448 + - num_heads: 12 + - attention_dropout_rate: 0.0 + - hidden_size: 768 + - mlp_dim: 3072 + - dropout_rate: 0.1 + - patch_size: 16 + - split: 'non-overlap' + - slide_step: 12 + - mlp_dim: 3072 + NECKS: + NAME: ~ + HEADS: + NAME: "mlp" + ARGS: + - hidden_size: 768 + - mlp_dim: 3072 + - dropout_rate: 0.1 + +TRANSFORMS: + TRAIN: + - name: "resize" + size: + - 600 + - 600 + - name: "random_crop" + size: 448 + padding: 0 + - name: "random_horizontal_flip" + prob: 0.5 + - name: "to_tensor" + - name: "normalize" + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + TEST: + - name: "resize" + size: + - 600 + - 600 + - name: "center_crop" + size: 448 + - name: "to_tensor" + - name: "normalize" + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + +OPTIMIZER: + NAME: SGD + MOMENTUM: 0.9 + WEIGHT_DECAY: 0.0 + LR: + base: 0.08 + backbone: 0.08 + encoder: 0.08 + necks: 0.08 + heads: 0.08 + +EPOCH_NUM: 100 +START_EPOCH: 0 +UPDATE_STRATEGY: "vit_update_strategy" +LR_SCHEDULE: + NAME: "warmup_linear_schedule" + ARGS: + - warmup_steps: 500 + - total_steps: 10000 +AMP: True + + +# Validation details +PER_ITERATION: ~ +PER_EPOCH: ~ +METRICS: + - name: "accuracy(topk=1)" + metric: "accuracy" + top_k: 1 + threshold: ~ + - name: "accuracy(topk=5)" + metric: "accuracy" + top_k: 5 + threshold: ~ + - name: "recall(threshold=0.5)" + metric: "recall" + top_k: ~ + threshold: 0.5 + - name: "precision(threshold=0.5)" + metric: "precision" + top_k: ~ + threshold: 0.5 
diff --git a/fgvclib/apis/__init__.py b/fgvclib/apis/__init__.py index 3d13266..962b610 100644 --- a/fgvclib/apis/__init__.py +++ b/fgvclib/apis/__init__.py @@ -1,4 +1,3 @@ from .build import * -from .evaluate_model import evaluate_model -from .update_model import update_model -from .save_model import save_model +from .save_model import * +from .seed import * diff --git a/fgvclib/apis/build.py b/fgvclib/apis/build.py index f3985af..7cbebba 100644 --- a/fgvclib/apis/build.py +++ b/fgvclib/apis/build.py @@ -16,20 +16,23 @@ from fgvclib.configs.utils import turn_list_to_dict as tltd from fgvclib.criterions import get_criterion from fgvclib.datasets import get_dataset -from fgvclib.samplers import get_sampler from fgvclib.datasets.datasets import FGVCDataset +from fgvclib.samplers import get_sampler from fgvclib.metrics import get_metric +from fgvclib.metrics.metrics import NamedMetric from fgvclib.models.sotas import get_model from fgvclib.models.sotas.sota import FGVCSOTA from fgvclib.models.backbones import get_backbone from fgvclib.models.encoders import get_encoder from fgvclib.models.necks import get_neck from fgvclib.models.heads import get_head +from fgvclib.optimizers import get_optimizer from fgvclib.transforms import get_transform from fgvclib.utils.logger import get_logger, Logger from fgvclib.utils.interpreter import get_interpreter, Interpreter from fgvclib.utils.lr_schedules import get_lr_schedule, LRSchedule -from fgvclib.metrics.metrics import NamedMetric +from fgvclib.utils.update_function import get_update_function +from fgvclib.utils.evaluate_function import get_evaluate_function def build_model(model_cfg: CfgNode) -> FGVCSOTA: @@ -64,7 +67,7 @@ def build_model(model_cfg: CfgNode) -> FGVCSOTA: criterions.update({item["name"]: {"fn": build_criterion(item), "w": item["w"]}}) model_builder = get_model(model_cfg.NAME) - model = model_builder(backbone=backbone, encoder=encoder, necks=necks, heads=heads, criterions=criterions) + model = model_builder(cfg=model_cfg, backbone=backbone, encoder=encoder, necks=necks, heads=heads, criterions=criterions) return model @@ -142,28 +145,56 @@ def build_optimizer(optim_cfg: CfgNode, model:t.Union[nn.Module, nn.DataParallel Optimizer: A Pytorch Optimizer. 
""" - params= list() + params = list() model_attrs = ["backbone", "encoder", "necks", "heads"] - if isinstance(model, nn.DataParallel) or isinstance(model, nn.parallel.DistributedDataParallel): - for attr in model_attrs: - if getattr(model.module, attr) and optim_cfg.LR[attr]: - params.append({ - 'params': getattr(model.module, attr).parameters(), - 'lr': optim_cfg.LR[attr] - }) - print(attr, optim_cfg.LR[attr]) + # if isinstance(model, nn.DataParallel) or isinstance(model, nn.parallel.DistributedDataParallel): + # for attr in model_attrs: + # if getattr(model.module, attr) and optim_cfg.LR[attr]: + # params.append({ + # 'params': getattr(model.module, attr).parameters(), + # 'lr': optim_cfg.LR[attr] + # }) + + # else: + # for attr in model_attrs: + # if getattr(model, attr) and optim_cfg.LR[attr]: + # params.append({ + # 'params': getattr(model, attr).parameters(), + # 'lr': optim_cfg.LR[attr] + # }) + + + if isinstance(model, nn.DataParallel) or isinstance(model, nn.parallel.DistributedDataParallel): + m = model.module else: - for attr in model_attrs: - if getattr(model, attr) and optim_cfg.LR[attr]: + m = model + for n, p in m.named_parameters(): + is_other = True + if p.requires_grad: + for attr in model_attrs: + if n.__contains__(attr): + is_other = False + params.append({ + 'params': p, + 'lr': optim_cfg.LR[attr] + }) + + if is_other: params.append({ - 'params': getattr(model, attr).parameters(), - 'lr': optim_cfg.LR[attr] + 'params': p, + 'lr': optim_cfg.LR["base"] }) + + - optimizer = optim.SGD(params=params, momentum=optim_cfg.MOMENTUM, weight_decay=optim_cfg.WEIGHT_DECAY) + # for n, p in m.named_parameters(): + + # if n.__contains__() + optimizer = get_optimizer(optim_cfg.NAME)(params, optim_cfg.LR.base, tltd(optim_cfg.ARGS)) + # optimizer = AdamW(params=params, lr=0.0001, weight_decay=5e-4) return optimizer def build_criterion(criterion_cfg: CfgNode) -> nn.Module: @@ -220,8 +251,8 @@ def build_sampler(sampler_cfg: CfgNode) -> Sampler: return get_sampler(sampler_cfg.NAME) -def build_lr_schedule(schedule_cfg: CfgNode) -> LRSchedule: - r"""Build metrics for evaluation. +def build_lr_schedule(optimizer, schedule_cfg: CfgNode, train_loader) -> LRSchedule: + r"""Build lr_schedule for training. Args: schedule_cfg (CfgNode): The schedule config node of root config node. @@ -229,6 +260,30 @@ def build_lr_schedule(schedule_cfg: CfgNode) -> LRSchedule: LRSchedule: A lr schedule. """ + batch_num_per_epoch = len(train_loader) + return get_lr_schedule(schedule_cfg.NAME)(optimizer, batch_num_per_epoch, tltd(schedule_cfg.ARGS)) - return get_lr_schedule(schedule_cfg.NAME)(tltd(schedule_cfg.ARGS)) +def build_update_function(cfg): + r"""Build metrics for evaluation. + + Args: + cfg (CfgNode): The root config node. + Returns: + function: A update model function. + + """ + + return get_update_function(cfg.UPDATE_FUNCTION) + + +def build_evaluate_function(cfg): + r"""Build metrics for evaluation. + + Args: + cfg (CfgNode): The root config node. + Returns: + function: A evaluate model function. + + """ + return get_evaluate_function(cfg.EVALUATE_FUNCTION) diff --git a/fgvclib/apis/evaluate_model.py b/fgvclib/apis/evaluate_model.py deleted file mode 100644 index 1eb8552..0000000 --- a/fgvclib/apis/evaluate_model.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) 2022-present, BUPT-PRIS. - -""" - This file provides a api for evaluating FGVC algorithms. 
-""" - - -import torch -import torch.nn as nn -from torch.autograd import Variable -import typing as t -from fgvclib.metrics.metrics import NamedMetric - - -def evaluate_model(model:nn.Module, p_bar:t.Iterable, metrics:t.List[NamedMetric], use_cuda:bool=True) -> t.Dict: - r"""Evaluate the FGVC model. - - Args: - model (nn.Module): - The FGVC model. - p_bar (iterable): - A iterator provide test data. - metrics (List[NamedMetric]): - List of metrics. - use_cuda (boolean, optional): - Whether to use gpu. - - Returns: - dict: The result dict. - """ - - model.eval() - results = dict() - - with torch.no_grad(): - for _, (inputs, targets) in enumerate(p_bar): - if use_cuda: - inputs, targets = inputs.cuda(), targets.cuda() - inputs, targets = Variable(inputs), Variable(targets) - for metric in metrics: - _ = metric.update(model(inputs), targets) - - for metric in metrics: - result = metric.compute() - results.update({ - metric.name: round(result.item(), 3) - }) - - return results diff --git a/fgvclib/apis/seed.py b/fgvclib/apis/seed.py new file mode 100644 index 0000000..1b5fa96 --- /dev/null +++ b/fgvclib/apis/seed.py @@ -0,0 +1,11 @@ +import random +import numpy as np +import torch + + +def set_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + diff --git a/fgvclib/configs/config.py b/fgvclib/configs/config.py index 17e6778..07d4264 100644 --- a/fgvclib/configs/config.py +++ b/fgvclib/configs/config.py @@ -17,6 +17,9 @@ def __init__(self): # Name of experiment self.cfg.EXP_NAME = None + # Random Seed + self.cfg.SEED = 0 + # Resume last train self.cfg.RESUME_WEIGHT = None @@ -57,7 +60,6 @@ def __init__(self): self.cfg.DATASET.TEST.SHUFFLE = False self.cfg.DATASET.TEST.NUM_WORKERS = 0 - # sampler for dataloader self.cfg.SAMPLER = CN() @@ -72,6 +74,7 @@ def __init__(self): self.cfg.SAMPLER.TEST.ARGS = None self.cfg.SAMPLER.TEST.IS_BATCH_SAMPLER = False + # Model architecture self.cfg.MODEL = CN() self.cfg.MODEL.NAME = None @@ -109,9 +112,10 @@ def __init__(self): # Optimizer self.cfg.OPTIMIZER = CN() self.cfg.OPTIMIZER.NAME = "SGD" - self.cfg.OPTIMIZER.MOMENTUM = 0.9 - self.cfg.OPTIMIZER.WEIGHT_DECAY = 5e-4 + self.cfg.OPTIMIZER.ARGS = [{"momentum": 0.9}, {"weight_decay": 5e-4}] + self.cfg.OPTIMIZER.LR = CN() + self.cfg.OPTIMIZER.LR.base = None self.cfg.OPTIMIZER.LR.backbone = None self.cfg.OPTIMIZER.LR.encoder = None self.cfg.OPTIMIZER.LR.necks = None @@ -121,15 +125,18 @@ def __init__(self): self.cfg.ITERATION_NUM = None self.cfg.EPOCH_NUM = None self.cfg.START_EPOCH = None - self.cfg.UPDATE_STRATEGY = None + self.cfg.UPDATE_FUNCTION = "general_update" + self.cfg.UPDATE_STRATEGY = "general_strategy" self.cfg.LR_SCHEDULE = CN() self.cfg.LR_SCHEDULE.NAME = "cosine_anneal_schedule" self.cfg.LR_SCHEDULE.ARGS = None + self.cfg.AMP = True # Validation self.cfg.PER_ITERATION = None self.cfg.PER_EPOCH = None self.cfg.METRICS = None + self.cfg.EVALUATE_FUNCTION = "general_evaluate" # Inference self.cfg.FIFTYONE = CN() diff --git a/fgvclib/criterions/mutual_channel_loss.py b/fgvclib/criterions/mutual_channel_loss.py index b2f5410..4dfe631 100644 --- a/fgvclib/criterions/mutual_channel_loss.py +++ b/fgvclib/criterions/mutual_channel_loss.py @@ -7,6 +7,7 @@ from torch.nn.modules.utils import _pair from .utils import LossItem +from . 
import criterion class MutualChannelLoss(nn.Module): @@ -147,7 +148,7 @@ def __repr__(self) -> str: + ', count_include_pad=' + str(self.count_include_pad) + ')' - +@criterion("mutual_channel_loss") def mutual_channel_loss(cfg=None): assert 'height' in cfg.keys(), 'height must exist in parameters' assert 'cnum' in cfg.keys(), 'cnum must exist in parameters' diff --git a/fgvclib/criterions/nll_loss_labelsmoothing.py b/fgvclib/criterions/nll_loss_labelsmoothing.py new file mode 100644 index 0000000..c8a8690 --- /dev/null +++ b/fgvclib/criterions/nll_loss_labelsmoothing.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn + +from . import criterion + + +class LabelSmoothing(nn.Module): + """ + NLL loss with label smoothing. + """ + + def __init__(self, smoothing): + """ + Constructor for the LabelSmoothing module. + :param smoothing: label smoothing factor + """ + super(LabelSmoothing, self).__init__() + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + + def forward(self, x, target): + logprobs = torch.nn.functional.log_softmax(x, dim=-1) + + nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) + nll_loss = nll_loss.squeeze(1) + smooth_loss = -logprobs.mean(dim=-1) + loss = self.confidence * nll_loss + self.smoothing * smooth_loss + return loss.mean() + +@criterion("nll_loss_labelsmoothing") +def nll_loss_labelsmoothing(cfg=None): + return LabelSmoothing(smoothing=cfg['smoothing_value']) diff --git a/fgvclib/metrics/metrics.py b/fgvclib/metrics/metrics.py index 00845ed..8543302 100644 --- a/fgvclib/metrics/metrics.py +++ b/fgvclib/metrics/metrics.py @@ -84,7 +84,7 @@ def recall(name:str="recall(threshold=0.5)", top_k:int=None, threshold:float=0.5 top_k (int): Number of the highest probability or logit score predictions considered finding the correct label. - threshhold (float, optional): + threshold (float, optional): Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. diff --git a/fgvclib/models/__init__.py b/fgvclib/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fgvclib/models/backbones/swin_t.py b/fgvclib/models/backbones/swin_t.py new file mode 100644 index 0000000..e684f5c --- /dev/null +++ b/fgvclib/models/backbones/swin_t.py @@ -0,0 +1,29 @@ +import timm +import torch +from fgvclib.models.backbones import backbone + + +def load_model_weights(model, model_path): + ### reference https://github.com/TACJu/TransFG + ### thanks a lot. 
+ state = torch.load(model_path, map_location='cpu') + for key in model.state_dict(): + if 'num_batches_tracked' in key: + continue + p = model.state_dict()[key] + if key in state['state_dict']: + ip = state['state_dict'][key] + if p.shape == ip.shape: + p.data.copy_(ip.data) # Copy the data of parameters + else: + print('could not load layer: {}, mismatch shape {} ,{}'.format(key, (p.shape), (ip.shape))) + else: + print('could not load layer: {}, not in checkpoint'.format(key)) + return model + + +@backbone("swint") +def swint(cfg): + backbone = timm.create_model('swin_large_patch4_window12_384_in22k', pretrained=cfg['pretrained']) + backbone.train() + return backbone diff --git a/fgvclib/models/backbones/swin_transformer.py b/fgvclib/models/backbones/swin_transformer.py new file mode 100644 index 0000000..9e28040 --- /dev/null +++ b/fgvclib/models/backbones/swin_transformer.py @@ -0,0 +1,706 @@ +""" Swin Transformer +A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` + - https://arxiv.org/pdf/2103.14030 + +Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below + +""" +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- +import logging +import math +from copy import deepcopy +from typing import Optional + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.helpers import build_model_with_cfg +from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, lecun_normal_ +from timm.models.vision_transformer import checkpoint_filter_fn + +from . import backbone + +_logger = logging.getLogger(__name__) + + +def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): + """ ViT weight initialization + * When called without n, head_bias, jax_impl args it will behave exactly the same + as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). 
+ * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl + """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + elif name.startswith('pre_logits'): + lecun_normal_(module.weight) + nn.init.zeros_(module.bias) + else: + if jax_impl: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + else: + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif jax_impl and isinstance(module, nn.Conv2d): + # NOTE conv was left to pytorch default in my original init + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): + nn.init.zeros_(module.bias) + nn.init.ones_(module.weight) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (my experiments) + 'swin_base_patch4_window12_384': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_base_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth', + ), + + 'swin_large_patch4_window12_384': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_large_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth', + ), + + 'swin_small_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth', + ), + + 'swin_tiny_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth', + ), + + 'swin_base_patch4_window12_384_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth', + input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), + + 'swin_base_patch4_window7_224_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth', + num_classes=21841), + + 'swin_large_patch4_window12_384_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth', + input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), + + 'swin_large_patch4_window7_224_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth', + num_classes=21841), + +} + + +def window_partition(x, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + 
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + 
relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, + attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + # H *= 2 + # W *= 2 + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + # H *= 2 + # W *= 2 + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 
+ + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if not torch.jit.is_scripting() and self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. 
Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), + window_size=7, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, weight_init='', **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + self.patch_grid = self.patch_embed.grid_size + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + else: + self.absolute_pos_embed = None + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + layers = [] + for i_layer in range(self.num_layers): + layers += [BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + input_resolution=(self.patch_grid[0] // (2 ** i_layer), self.patch_grid[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + ] + self.layers = nn.Sequential(*layers) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0. 
+ if weight_init.startswith('jax'): + for n, m in self.named_modules(): + _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True) + else: + self.apply(_init_vit_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + if self.absolute_pos_embed is not None: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + l1 = self.layers[0](x) + l2 = self.layers[1](l1) + l3 = self.layers[2](l2) + l4 = self.layers[3](l3) + # x = self.norm(l4) # B L C + # x = self.avgpool(x.transpose(1, 2)) # B C 1 + # x = torch.flatten(x, 1) + return {"layer1":l1, "layer2":l2, "layer3":l3, "layer4":l4} + + def forward(self, x): + x = self.forward_features(x) + # x = self.head(x) + return x + +def overlay_external_default_cfg(default_cfg, kwargs): + """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg. + """ + external_default_cfg = kwargs.pop('external_default_cfg', None) + if external_default_cfg: + default_cfg.pop('url', None) # url should come from external cfg + default_cfg.pop('hf_hub', None) # hf hub id should come from external cfg + default_cfg.update(external_default_cfg) + + +def _create_swin_transformer(variant, cfg, default_cfg=None, **kwargs): + pretrained = False if "pretrained" not in cfg.keys() else cfg['pretrained'] + if default_cfg is None: + default_cfg = deepcopy(default_cfgs[variant]) + overlay_external_default_cfg(default_cfg, kwargs) + default_num_classes = default_cfg['num_classes'] + default_img_size = default_cfg['input_size'][-2:] + + num_classes = kwargs.pop('num_classes', default_num_classes) + img_size = kwargs.pop('img_size', default_img_size) + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + SwinTransformer, variant, pretrained, + default_cfg=default_cfg, + img_size=img_size, + num_classes=num_classes, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + + return model + + +@backbone('swin_base_patch4_window12_384') +def swin_base_patch4_window12_384(cfg, **kwargs): + """ Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k + """ + model_kwargs = dict( + patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer('swin_base_patch4_window12_384', cfg, **model_kwargs) + + +@backbone('swin_base_patch4_window7_224') +def swin_base_patch4_window7_224(cfg, **kwargs): + """ Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer('swin_base_patch4_window7_224', cfg, **model_kwargs) + + +@backbone('swin_large_patch4_window12_384') +def swin_large_patch4_window12_384(cfg, **kwargs): + """ Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k + """ + model_kwargs = dict( + patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer('swin_large_patch4_window12_384', cfg, **model_kwargs) + + 
+@backbone('swin_large_patch4_window7_224') +def swin_large_patch4_window7_224(cfg, **kwargs): + """ Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer('swin_large_patch4_window7_224', cfg, **model_kwargs) + + +@backbone('swin_small_patch4_window7_224') +def swin_small_patch4_window7_224(cfg, **kwargs): + """ Swin-S @ 224x224, trained ImageNet-1k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer('swin_small_patch4_window7_224', cfg, **model_kwargs) + + +@backbone('swin_tiny_patch4_window7_224') +def swin_tiny_patch4_window7_224(cfg, **kwargs): + """ Swin-T @ 224x224, trained ImageNet-1k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer('swin_tiny_patch4_window7_224', cfg, **model_kwargs) + + +@backbone('swin_base_patch4_window12_384_in22k') +def swin_base_patch4_window12_384_in22k(cfg, **kwargs): + """ Swin-B @ 384x384, trained ImageNet-22k + """ + model_kwargs = dict( + patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer('swin_base_patch4_window12_384_in22k', cfg, **model_kwargs) + + +@backbone('swin_base_patch4_window7_224_in22k') +def swin_base_patch4_window7_224_in22k(cfg, **kwargs): + """ Swin-B @ 224x224, trained ImageNet-22k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer('swin_base_patch4_window7_224_in22k', cfg, **model_kwargs) + + +@backbone('swin_large_patch4_window12_384_in22k') +def swin_large_patch4_window12_384_in22k(cfg, **kwargs): + """ Swin-L @ 384x384, trained ImageNet-22k + """ + model_kwargs = dict( + patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer('swin_large_patch4_window12_384_in22k', cfg, **model_kwargs) + + +@backbone('swin_large_patch4_window7_224_in22k') +def swin_large_patch4_window7_224_in22k(cfg, **kwargs): + """ Swin-L @ 224x224, trained ImageNet-22k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer('swin_large_patch4_window7_224_in22k', cfg, **model_kwargs) + \ No newline at end of file diff --git a/fgvclib/models/backbones/vit.py b/fgvclib/models/backbones/vit.py new file mode 100644 index 0000000..23c328c --- /dev/null +++ b/fgvclib/models/backbones/vit.py @@ -0,0 +1,80 @@ +import logging +import torch +import torch.nn as nn + +from torch.nn import Dropout, Conv2d +from torch.nn.modules.utils import _pair +from fgvclib.models.backbones import backbone + +# official pretrain weights +model_urls = { + 'ViT-B_16': 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz', + 'ViT-B_32': 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_32.npz', + 'ViT-L_16': 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-L_16.npz', + 'ViT-L_32': 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-L_32.npz', + 'ViT-H_14': 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz' +} + +cfgs = { + 'ViT-B_16', + 'ViT-B_32', + 
'ViT-L_16', + 'ViT-L_32', + 'ViT-H_14' +} + +logger = logging.getLogger(__name__) + + +class Embeddings(nn.Module): + """Construct the embeddings from patch, position embeddings. + """ + + def __init__(self, cfg: dict, img_size, in_channels=3): + super(Embeddings, self).__init__() + self.hybrid = None + img_size = _pair(img_size) + + patch_size = cfg["patch_size"] + if cfg['split'] == 'non-overlap': + n_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size) + self.patch_embeddings = Conv2d(in_channels=in_channels, + out_channels=cfg['hidden_size'], + kernel_size=patch_size, + stride=patch_size) + elif cfg['split'] == 'overlap': + n_patches = ((img_size[0] - patch_size) // cfg['slide_step'] + 1) * ( + (img_size[1] - patch_size) // cfg['slide_step'] + 1) + self.patch_embeddings = Conv2d(in_channels=in_channels, + out_channels=cfg['hidden_size'], + kernel_size=patch_size, + stride=(cfg['slide_step'], cfg['slide_step'])) + self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches + 1, cfg['hidden_size'])) + self.cls_token = nn.Parameter(torch.zeros(1, 1, cfg['hidden_size'])) + + self.dropout = Dropout(cfg['dropout_rate']) + + def forward(self, x): + B = x.shape[0] + cls_tokens = self.cls_token.expand(B, -1, -1) + + if self.hybrid: + x = self.hybrid_model(x) + x = self.patch_embeddings(x) + x = x.flatten(2) + x = x.transpose(-1, -2) + x = torch.cat((cls_tokens, x), dim=1) + + embeddings = x + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings + + +def _vit(cfg, model_name="ViT-B_16", **kwargs): + assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name) + return Embeddings(cfg, img_size=cfg['image_size']) + + +@backbone("vit16") +def vit16(cfg): + return _vit(cfg, "ViT-B_16") diff --git a/fgvclib/models/encoders/fpn.py b/fgvclib/models/encoders/fpn.py new file mode 100644 index 0000000..53ea3f7 --- /dev/null +++ b/fgvclib/models/encoders/fpn.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn + +from fgvclib.models.encoders import encoder + + +class FPN(nn.Module): + + def __init__(self, inputs: dict, fpn_size: int, proj_type: str, upsample_type: str): + """ + inputs : dictionary contains torch.Tensor + which comes from backbone output + fpn_size: integer, fpn + proj_type: + in ["Conv", "Linear"] + upsample_type: + in ["Bilinear", "Conv", "Fc"] + for convolution neural network (e.g. ResNet, EfficientNet), recommand 'Bilinear'. + for Vit, "Fc". 
and Swin-T, "Conv" + """ + super(FPN, self).__init__() + assert proj_type in ["Conv", "Linear"], \ + "FPN projection type {} were not support yet, please choose type 'Conv' or 'Linear'".format(proj_type) + assert upsample_type in ["Bilinear", "Conv"], \ + "FPN upsample type {} were not support yet, please choose type 'Bilinear' or 'Conv'".format(proj_type) + + self.fpn_size = fpn_size + self.upsample_type = upsample_type + inp_names = [name for name in inputs] + + for i, node_name in enumerate(inputs): + ### projection module + if proj_type == "Conv": + m = nn.Sequential( + nn.Conv2d(inputs[node_name].size(1), inputs[node_name].size(1), 1), + nn.ReLU(), + nn.Conv2d(inputs[node_name].size(1), fpn_size, 1) + ) + elif proj_type == "Linear": + m = nn.Sequential( + nn.Linear(inputs[node_name].size(-1), inputs[node_name].size(-1)), + nn.ReLU(), + nn.Linear(inputs[node_name].size(-1), fpn_size), + ) + self.add_module("Proj_" + node_name, m) + + ### upsample module + if upsample_type == "Conv" and i != 0: + assert len(inputs[node_name].size()) == 3 # B, S, C + in_dim = inputs[node_name].size(1) + out_dim = inputs[inp_names[i - 1]].size(1) + if in_dim != out_dim: + m = nn.Conv1d(in_dim, out_dim, 1) # for spatial domain + else: + m = nn.Identity() + self.add_module("Up_" + node_name, m) + + if upsample_type == "Bilinear": + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') + + def upsample_add(self, x0: torch.Tensor, x1: torch.Tensor, x1_name: str): + """ + return Upsample(x1) + x1 + """ + if self.upsample_type == "Bilinear": + if x1.size(-1) != x0.size(-1): + x1 = self.upsample(x1) + else: + x1 = getattr(self, "Up_" + x1_name)(x1) + return x1 + x0 + + def forward(self, x): + """ + x : dictionary + { + "node_name1": feature1, + "node_name2": feature2, ... 
+ } + """ + ### project to same dimension + hs = [] + for i, name in enumerate(x): + x[name] = getattr(self, "Proj_" + name)(x[name]) + hs.append(name) + + for i in range(len(hs) - 1, 0, -1): + x1_name = hs[i] + x0_name = hs[i - 1] + x[x0_name] = self.upsample_add(x[x0_name], + x[x1_name], + x1_name) + return x + + +@encoder("fpn") +def fpn(inputs, cfg: dict): + return FPN(inputs=inputs, fpn_size=cfg['fpn_size'], proj_type=cfg['proj_type'], upsample_type=cfg['upsample_type']) diff --git a/fgvclib/models/encoders/transformer_encoder.py b/fgvclib/models/encoders/transformer_encoder.py new file mode 100644 index 0000000..07d9b9d --- /dev/null +++ b/fgvclib/models/encoders/transformer_encoder.py @@ -0,0 +1,203 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import copy +from ..backbones.vit import Embeddings +import logging +import math +from os.path import join as pjoin +import torch +import torch.nn as nn +from torch.nn import Dropout, Softmax, Linear, LayerNorm + + +from fgvclib.models.encoders import encoder +from fgvclib.models.heads.mlp import mlp + +logger = logging.getLogger(__name__) + +ATTENTION_Q = "MultiHeadDotProductAttention_1/query" +ATTENTION_K = "MultiHeadDotProductAttention_1/key" +ATTENTION_V = "MultiHeadDotProductAttention_1/value" +ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" +FC_0 = "MlpBlock_3/Dense_0" +FC_1 = "MlpBlock_3/Dense_1" +ATTENTION_NORM = "LayerNorm_0" +MLP_NORM = "LayerNorm_2" + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +class Attention(nn.Module): + def __init__(self, cfg: dict): + super(Attention, self).__init__() + self.num_attention_heads = cfg['num_heads'] + self.attention_head_size = int(cfg['hidden_size'] / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(cfg['hidden_size'], self.all_head_size) + self.key = Linear(cfg['hidden_size'], self.all_head_size) + self.value = Linear(cfg['hidden_size'], self.all_head_size) + + self.out = Linear(cfg['hidden_size'], cfg['hidden_size']) + self.attn_dropout = Dropout(cfg['attention_dropout_rate']) + self.proj_dropout = Dropout(cfg['attention_dropout_rate']) + + self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + weights = attention_probs + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = 
self.proj_dropout(attention_output) + return attention_output, weights + + +class Block(nn.Module): + def __init__(self, cfg: dict): + super(Block, self).__init__() + self.hidden_size = cfg['hidden_size'] + self.attention_norm = LayerNorm(cfg['hidden_size'], eps=1e-6) + self.ffn_norm = LayerNorm(cfg['hidden_size'], eps=1e-6) + self.ffn = mlp(cfg) + self.attn = Attention(cfg) + + def forward(self, x): + h = x + x = self.attention_norm(x) + x, weights = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x, weights + + def load_from(self, weights, n_block): + ROOT = f"Transformer/encoderblock_{n_block}" + with torch.no_grad(): + query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, + self.hidden_size).t() + key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() + value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, + self.hidden_size).t() + out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, + self.hidden_size).t() + + query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) + key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) + value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) + out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) + + self.attn.query.weight.copy_(query_weight) + self.attn.key.weight.copy_(key_weight) + self.attn.value.weight.copy_(value_weight) + self.attn.out.weight.copy_(out_weight) + self.attn.query.bias.copy_(query_bias) + self.attn.key.bias.copy_(key_bias) + self.attn.value.bias.copy_(value_bias) + self.attn.out.bias.copy_(out_bias) + + mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() + mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() + mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() + mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() + + self.ffn.fc1.weight.copy_(mlp_weight_0) + self.ffn.fc2.weight.copy_(mlp_weight_1) + self.ffn.fc1.bias.copy_(mlp_bias_0) + self.ffn.fc2.bias.copy_(mlp_bias_1) + + self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) + self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) + self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) + self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) + + +class Part_Attention(nn.Module): + def __init__(self): + super(Part_Attention, self).__init__() + + def forward(self, x): + length = len(x) + last_map = x[0] + for i in range(1, length): + last_map = torch.matmul(x[i], last_map) + last_map = last_map[:, :, 0, 1:] + + _, max_inx = last_map.max(2) + return _, max_inx + + +class Encoder(nn.Module): + def __init__(self, cfg: dict): + super(Encoder, self).__init__() + self.layer = nn.ModuleList() + for _ in range(cfg['num_layers'] - 1): + layer = Block(cfg) + self.layer.append(copy.deepcopy(layer)) + self.part_select = Part_Attention() + self.part_layer = Block(cfg) + self.part_norm = LayerNorm(cfg['hidden_size'], eps=1e-6) + + def forward(self, hidden_states): + attn_weights = [] + for layer in self.layer: + hidden_states, weights = layer(hidden_states) + attn_weights.append(weights) + part_num, part_inx = self.part_select(attn_weights) + part_inx = part_inx + 1 + parts = [] + B, num = part_inx.shape + for i in range(B): + parts.append(hidden_states[i, part_inx[i, :]]) + parts = 
torch.stack(parts).squeeze(1) + concat = torch.cat((hidden_states[:, 0].unsqueeze(1), parts), dim=1) + part_states, part_weights = self.part_layer(concat) + part_encoded = self.part_norm(part_states) + + return part_encoded + + +class Transformer(nn.Module): + def __init__(self, cfg: dict, img_size): + super(Transformer, self).__init__() + self.embeddings = Embeddings(cfg, img_size=img_size) + self.encoder = Encoder(cfg) + + def forward(self, input_ids): + embedding_output = self.embeddings(input_ids) + part_encoded = self.encoder(embedding_output) + return part_encoded + + +@encoder("transformer_encoder") +def transformer_encoder(cfg: dict): + return Transformer(cfg, img_size=cfg['img_size']) diff --git a/fgvclib/models/heads/gcn_combiner.py b/fgvclib/models/heads/gcn_combiner.py new file mode 100644 index 0000000..aa10a6b --- /dev/null +++ b/fgvclib/models/heads/gcn_combiner.py @@ -0,0 +1,102 @@ +import torch +import torch.nn as nn +from typing import Union +import copy + +from fgvclib.models.heads import head + + +class GCNCombiner(nn.Module): + + def __init__(self, + total_num_selects: int, + num_classes: int, + inputs: Union[dict, None] = None, + proj_size: Union[int, None] = None, + fpn_size: Union[int, None] = None): + """ + If building backbone without FPN, set fpn_size to None and MUST give + 'inputs' and 'proj_size', the reason of these setting is to constrain the + dimension of graph convolutional network input. + """ + super(GCNCombiner, self).__init__() + + assert inputs is not None or fpn_size is not None, \ + "To build GCN combiner, you must give one features dimension." + + ### auto-proj + self.fpn_size = fpn_size + if fpn_size is None: + for name in inputs: + if len(name) == 4: + in_size = inputs[name].size(1) + elif len(name) == 3: + in_size = inputs[name].size(2) + else: + raise ValueError("The size of output dimension of previous must be 3 or 4.") + m = nn.Sequential( + nn.Linear(in_size, proj_size), + nn.ReLU(), + nn.Linear(proj_size, proj_size) + ) + self.add_module("proj_" + name, m) + self.proj_size = proj_size + else: + self.proj_size = fpn_size + + ### build one layer structure (with adaptive module) + num_joints = total_num_selects // 32 + + self.param_pool0 = nn.Linear(total_num_selects, num_joints) + + A = torch.eye(num_joints) / 100 + 1 / 100 + self.adj1 = nn.Parameter(copy.deepcopy(A)) + self.conv1 = nn.Conv1d(self.proj_size, self.proj_size, 1) + self.batch_norm1 = nn.BatchNorm1d(self.proj_size) + + self.conv_q1 = nn.Conv1d(self.proj_size, self.proj_size // 4, 1) + self.conv_k1 = nn.Conv1d(self.proj_size, self.proj_size // 4, 1) + self.alpha1 = nn.Parameter(torch.zeros(1)) + + ### merge information + self.param_pool1 = nn.Linear(num_joints, 1) + + #### class predict + self.dropout = nn.Dropout(p=0.1) + self.classifier = nn.Linear(self.proj_size, num_classes) + + self.tanh = nn.Tanh() + + def forward(self, x): + """ + """ + hs = [] + for name in x: + if self.fpn_size is None: + hs.append(getattr(self, "proj_" + name)(x[name])) + else: + hs.append(x[name]) + hs = torch.cat(hs, dim=1).transpose(1, 2).contiguous() # B, S', C --> B, C, S + hs = self.param_pool0(hs) + ### adaptive adjacency + q1 = self.conv_q1(hs).mean(1) + k1 = self.conv_k1(hs).mean(1) + A1 = self.tanh(q1.unsqueeze(-1) - k1.unsqueeze(1)) + A1 = self.adj1 + A1 * self.alpha1 + ### graph convolution + hs = self.conv1(hs) + hs = torch.matmul(hs, A1) + hs = self.batch_norm1(hs) + ### predict + hs = self.param_pool1(hs) + hs = self.dropout(hs) + hs = hs.flatten(1) + hs = self.classifier(hs) + + return 
hs + + +@head("GCN_combiner") +def GCN_combiner(cfg: dict, **kwargs): + return GCNCombiner(total_num_selects=cfg['total_num_selects'], num_classes=cfg['num_classes'], + fpn_size=cfg['fpn_size']) diff --git a/fgvclib/models/heads/mlp.py b/fgvclib/models/heads/mlp.py new file mode 100644 index 0000000..1dbc12c --- /dev/null +++ b/fgvclib/models/heads/mlp.py @@ -0,0 +1,49 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +from torch.nn import Dropout, Linear + +from fgvclib.models.heads import head + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class Mlp(nn.Module): + def __init__(self, hidden_size, mlp_dim, dropout_rate): + super(Mlp, self).__init__() + self.fc1 = Linear(hidden_size, mlp_dim) + self.fc2 = Linear(mlp_dim, hidden_size) + self.act_fn = ACT2FN["gelu"] + self.dropout = Dropout(dropout_rate) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +@head("mlp") +def mlp(cfg: dict, **kwargs) -> Mlp: + assert 'hidden_size' in cfg.keys() + assert 'mlp_dim' in cfg.keys() + assert 'dropout_rate' in cfg.keys() + return Mlp(hidden_size=cfg['hidden_size'], mlp_dim=cfg['mlp_dim'], dropout_rate=cfg['dropout_rate']) diff --git a/fgvclib/models/necks/weakly_selector.py b/fgvclib/models/necks/weakly_selector.py new file mode 100644 index 0000000..0390c4a --- /dev/null +++ b/fgvclib/models/necks/weakly_selector.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +from typing import Union + +from fgvclib.models.necks import neck + + +class WeaklySelector(nn.Module): + + def __init__(self, inputs: dict, num_classes: int, num_select: dict, fpn_size: Union[int, None] = None): + """ + inputs: dictionary contain torch.Tensors, which comes from backbone + [Tensor1(hidden feature1), Tensor2(hidden feature2)...] + Please note that if len(features.size) equal to 3, the order of dimension must be [B,S,C], + S mean the spatial domain, and if len(features.size) equal to 4, the order must be [B,C,H,W] + + """ + super(WeaklySelector, self).__init__() + + self.num_select = num_select + + self.fpn_size = fpn_size + ### build classifier + if self.fpn_size is None: + self.num_classes = num_classes + for name in inputs: + fs_size = inputs[name].size() + if len(fs_size) == 3: + in_size = fs_size[2] + elif len(fs_size) == 4: + in_size = fs_size[1] + m = nn.Linear(in_size, num_classes) + self.add_module("classifier_l_" + name, m) + + def forward(self, x, logits=None): + """ + x : + dictionary contain the features maps which + come from your choosen layers. + size must be [B, HxW, C] ([B, S, C]) or [B, C, H, W]. + [B,C,H,W] will be transpose to [B, HxW, C] automatically. 
+ """ + i = 0 + if self.fpn_size is None: + logits = {} + selections = {} + for name in x: + if len(x[name].size()) == 4: + B, C, H, W = x[name].size() + x[name] = x[name].view(B, C, H * W).permute(0, 2, 1).contiguous() + C = x[name].size(-1) + if self.fpn_size is None: + logits[name] = getattr(self, "classifier_l_" + name)(x[name]) + + probs = torch.softmax(logits[name], dim=-1) + selections[name] = [] + preds_1 = [] + preds_0 = [] + num_select = self.num_select[i][name] + i = i + 1 + for bi in range(logits[name].size(0)): + max_ids, _ = torch.max(probs[bi], dim=-1) + confs, ranks = torch.sort(max_ids, descending=True) + sf = x[name][bi][ranks[:num_select]] + nf = x[name][bi][ranks[num_select:]] # calculate + selections[name].append(sf) # [num_selected, C] + preds_1.append(logits[name][bi][ranks[:num_select]]) + preds_0.append(logits[name][bi][ranks[num_select:]]) + + selections[name] = torch.stack(selections[name]) + preds_1 = torch.stack(preds_1) + preds_0 = torch.stack(preds_0) + + logits["select_" + name] = preds_1 + logits["drop_" + name] = preds_0 + + return selections + + +@neck("weakly_selector") +def weakly_selector(inputs, cfg: dict): + return WeaklySelector(inputs=inputs, num_classes=cfg['num_classes'], num_select=cfg['num_selects'], + fpn_size=cfg['fpn_size']) diff --git a/fgvclib/models/sotas/api_net.py b/fgvclib/models/sotas/api_net.py index badf664..5e50580 100644 --- a/fgvclib/models/sotas/api_net.py +++ b/fgvclib/models/sotas/api_net.py @@ -14,8 +14,8 @@ class APINet(FGVCSOTA): Link: https://github.com/PeiqinZhuang/API-Net """ - def __init__(self, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: dict): - super().__init__(backbone, encoder, necks, heads, criterions) + def __init__(self, cfg: dict, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): + super().__init__(cfg, backbone, encoder, necks, heads, criterions) def forward(self, images, targets=None): diff --git a/fgvclib/models/sotas/cal.py b/fgvclib/models/sotas/cal.py index 466ca9d..1ccc497 100644 --- a/fgvclib/models/sotas/cal.py +++ b/fgvclib/models/sotas/cal.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from yacs.config import CfgNode from .sota import FGVCSOTA from fgvclib.criterions.utils import LossItem @@ -67,13 +68,14 @@ class WSDAN_CAL(FGVCSOTA): Link: https://github.com/raoyongming/CAL """ - def __init__(self, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: dict): - super().__init__(backbone, encoder, necks, heads, criterions) - self.num_classes = 200 + def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): + super().__init__(cfg, backbone, encoder, necks, heads, criterions) + + self.out_channels = 2048 + self.num_classes = cfg.CLASS_NUM self.M = 32 self.net = 'resnet101' - - self.register_buffer('feature_center', torch.zeros(self.num_classes, self.M * 2048)) # 32 * 2048 + self.register_buffer('feature_center', torch.zeros(self.num_classes, self.M * self.out_channels)) # 32 * 2048 def infer(self, x): @@ -168,3 +170,36 @@ def load_state_dict(self, state_dict, strict=True): model_dict.update(pretrained_dict) super(WSDAN_CAL, self).load_state_dict(model_dict) + + def infer_aux(self, x): + x_m = torch.flip(x, [3]) + + y_pred_raw, y_pred_aux_raw, _, attention_map = self.infer(x) + y_pred_raw_m, y_pred_aux_raw_m, _, attention_map_m = self.infer(x_m) + crop_images = 
batch_augment(x, attention_map, mode='crop', theta=0.3, padding_ratio=0.1) + y_pred_crop, y_pred_aux_crop, _, _ = self.infer(crop_images) + + crop_images2 = batch_augment(x, attention_map, mode='crop', theta=0.2, padding_ratio=0.1) + y_pred_crop2, y_pred_aux_crop2, _, _ = self.infer(crop_images2) + + crop_images3 = batch_augment(x, attention_map, mode='crop', theta=0.1, padding_ratio=0.05) + y_pred_crop3, y_pred_aux_crop3, _, _ = self.infer(crop_images3) + + crop_images_m = batch_augment(x_m, attention_map_m, mode='crop', theta=0.3, padding_ratio=0.1) + y_pred_crop_m, y_pred_aux_crop_m, _, _ = self.infer(crop_images_m) + + crop_images_m2 = batch_augment(x_m, attention_map_m, mode='crop', theta=0.2, padding_ratio=0.1) + y_pred_crop_m2, y_pred_aux_crop_m2, _, _ = self.infer(crop_images_m2) + + crop_images_m3 = batch_augment(x_m, attention_map_m, mode='crop', theta=0.1, padding_ratio=0.05) + y_pred_crop_m3, y_pred_aux_crop_m3, _, _ = self.infer(crop_images_m3) + + y_pred = (y_pred_raw + y_pred_crop + y_pred_crop2 + y_pred_crop3) / 4. + y_pred_m = (y_pred_raw_m + y_pred_crop_m + y_pred_crop_m2 + y_pred_crop_m3) / 4. + y_pred = (y_pred + y_pred_m) / 2. + + y_pred_aux = (y_pred_aux_raw + y_pred_aux_crop + y_pred_aux_crop2 + y_pred_aux_crop3) / 4. + y_pred_aux_m = (y_pred_aux_raw_m + y_pred_aux_crop_m + y_pred_aux_crop_m2 + y_pred_aux_crop_m3) / 4. + y_pred_aux = (y_pred_aux + y_pred_aux_m) / 2. + + return y_pred_aux diff --git a/fgvclib/models/sotas/mcl.py b/fgvclib/models/sotas/mcl.py index ffde2a3..e181fc4 100644 --- a/fgvclib/models/sotas/mcl.py +++ b/fgvclib/models/sotas/mcl.py @@ -1,5 +1,6 @@ from torch import nn +from yacs.config import CfgNode from fgvclib.models.sotas.sota import FGVCSOTA from fgvclib.models.sotas import fgvcmodel @@ -12,14 +13,16 @@ class MCL(FGVCSOTA): Link: https://github.com/PRIS-CV/Mutual-Channel-Loss """ - def __init__(self, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): - super().__init__(backbone, encoder, necks, heads, criterions) + def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): + super().__init__(cfg, backbone, encoder, necks, heads, criterions) + self.num_classes = cfg.CLASS_NUM + def forward(self, x, targets=None): x = self.backbone(x) if self.training: losses = list() - losses.extend(self.criterions['mutual_channel_loss']['fn'](x, targets, self.heads.get_class_num())) + losses.extend(self.criterions['mutual_channel_loss']['fn'](x, targets, self.num_classes)) x = self.encoder(x) x = x.view(x.size(0), -1) diff --git a/fgvclib/models/sotas/pmg.py b/fgvclib/models/sotas/pmg.py index 5aa394f..e8f1f7c 100644 --- a/fgvclib/models/sotas/pmg.py +++ b/fgvclib/models/sotas/pmg.py @@ -1,5 +1,6 @@ import torch.nn as nn import random +from yacs.config import CfgNode from fgvclib.models.sotas.sota import FGVCSOTA from fgvclib.criterions.utils import LossItem @@ -14,8 +15,8 @@ class PMG(FGVCSOTA): """ - def __init__(self, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): - super().__init__(backbone, encoder, necks, heads, criterions) + def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): + super().__init__(cfg, backbone, encoder, necks, heads, criterions) self.outputs_num = 4 diff --git a/fgvclib/models/sotas/pmg_v2.py b/fgvclib/models/sotas/pmg_v2.py index b124c7b..10d58b5 100644 --- 
a/fgvclib/models/sotas/pmg_v2.py +++ b/fgvclib/models/sotas/pmg_v2.py @@ -1,4 +1,5 @@ import torch.nn as nn +from yacs.config import CfgNode from fgvclib.models.sotas.sota import FGVCSOTA from fgvclib.criterions import LossItem @@ -16,11 +17,12 @@ class PMG_V2(FGVCSOTA): BLOCKS = [[8, 8, 0, 0], [4, 4, 4, 0], [2, 2, 2, 2]] alpha = [0.01, 0.05, 0.1] - def __init__(self, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): - super().__init__(backbone, encoder, necks, heads, criterions) - - self.outputs_num = 3 - + def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): + super().__init__(cfg, backbone, encoder, necks, heads, criterions) + self.BLOCKS = self.args["BLOCKS"] + self.outputs_num = self.args["outputs_num"] + self.alpha = self.args["alpha"] + def forward(self, x, targets=None, step:int=None, batch_size:int=None): if step is not None: diff --git a/fgvclib/models/sotas/resnet50.py b/fgvclib/models/sotas/resnet50.py index 0769cc7..e9f8e62 100644 --- a/fgvclib/models/sotas/resnet50.py +++ b/fgvclib/models/sotas/resnet50.py @@ -1,4 +1,5 @@ import torch.nn as nn +from yacs.config import CfgNode from fgvclib.models.sotas.sota import FGVCSOTA from fgvclib.models.sotas import fgvcmodel @@ -9,8 +10,8 @@ @fgvcmodel("ResNet50") class ResNet50(FGVCSOTA): - def __init__(self, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: dict): - super().__init__(backbone, encoder, necks, heads, criterions) + def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): + super().__init__(cfg, backbone, encoder, necks, heads, criterions) def forward(self, x, targets=None): x = self.infer(x) diff --git a/fgvclib/models/sotas/sota.py b/fgvclib/models/sotas/sota.py index 5e85845..2145056 100644 --- a/fgvclib/models/sotas/sota.py +++ b/fgvclib/models/sotas/sota.py @@ -1,17 +1,22 @@ -import torch +from yacs.config import CfgNode import torch.nn as nn from thop import profile +from fgvclib.configs.utils import turn_list_to_dict as tltd + + class FGVCSOTA(nn.Module): - def __init__(self, backbone:nn.Module, encoder:nn.Module, necks:nn.Module, heads:nn.Module, criterions:nn.Module): + def __init__(self, cfg:CfgNode, backbone:nn.Module, encoder:nn.Module, necks:nn.Module, heads:nn.Module, criterions:nn.Module): super(FGVCSOTA, self).__init__() - + + self.cfg = cfg self.backbone = backbone self.necks = necks self.encoder = encoder self.heads = heads self.criterions = criterions + self.args = tltd(cfg.ARGS) def get_structure(self): ss = f"\n{'=' * 30}\n" + f"\nThe Structure of {self.__class__.__name__}:\n" + f"\n{'=' * 30}\n\n" diff --git a/fgvclib/models/sotas/swin_transformer.py b/fgvclib/models/sotas/swin_transformer.py new file mode 100644 index 0000000..ac204e9 --- /dev/null +++ b/fgvclib/models/sotas/swin_transformer.py @@ -0,0 +1,187 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from yacs.config import CfgNode + +from .sota import FGVCSOTA +from fgvclib.models.sotas import fgvcmodel +from fgvclib.criterions.utils import LossItem +from fgvclib.models.encoders.fpn import FPN +from fgvclib.models.necks.weakly_selector import WeaklySelector + +@fgvcmodel("SwinTransformer") +class SwinTransformer(FGVCSOTA): + def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): + 
super().__init__(cfg, backbone, encoder, necks, heads, criterions)
+
+        ### get hidden features size
+        self.num_classes = cfg.CLASS_NUM
+        self.use_fpn = self.args["use_fpn"]
+        self.lambda_s = self.args["lambda_s"]
+        self.lambda_n = self.args["lambda_n"]
+        self.lambda_b = self.args["lambda_b"]
+        self.lambda_c = self.args["lambda_c"]
+        self.use_combiner = self.args["use_combiner"]
+        self.update_freq = self.args["update_freq"]
+        self.use_selection = self.args["use_selection"]
+        num_select = self.args["num_select"]
+
+        ### probe the backbone output shapes with a dummy forward pass
+        input_size = self.args["img_size"]
+        rand_in = torch.randn(1, 3, input_size, input_size)
+        backbone_outs = self.backbone(rand_in)
+
+        if self.use_fpn:
+            fpn_size = self.args["fpn_size"]
+            self.encoder = FPN(inputs=backbone_outs, fpn_size=fpn_size, proj_type="Linear", upsample_type="Conv")
+            fpn_outs = self.encoder(backbone_outs)
+        else:
+            fpn_outs = backbone_outs
+            fpn_size = None
+
+        self.necks = WeaklySelector(inputs=fpn_outs, num_classes=self.num_classes, num_select=num_select, fpn_size=fpn_size)
+
+        ### = = = = = FPN = = = = =
+        self.fpn = self.encoder
+        self.build_fpn_classifier(backbone_outs, fpn_size, self.num_classes)
+
+        ### = = = = = Selector = = = = =
+        self.selector = self.necks
+
+        ### = = = = = Combiner = = = = =
+        self.combiner = self.heads
+
+        ### just original backbone
+        if not self.fpn and (not self.combiner):
+            for name in backbone_outs:
+                fs_size = backbone_outs[name].size()
+                if len(fs_size) == 3:
+                    out_size = fs_size[-1]
+                elif len(fs_size) == 4:
+                    out_size = fs_size[1]
+                else:
+                    raise ValueError("The size of output dimension of previous must be 3 or 4.")
+            self.classifier = nn.Linear(out_size, self.num_classes)
+
+    def build_fpn_classifier(self, inputs: dict, fpn_size: int, num_classes: int):
+        for name in inputs:
+            m = nn.Sequential(
+                nn.Conv1d(fpn_size, fpn_size, 1),
+                nn.BatchNorm1d(fpn_size),
+                nn.ReLU(),
+                nn.Conv1d(fpn_size, num_classes, 1)
+            )
+            self.add_module("fpn_classifier_" + name, m)
+
+    def forward_backbone(self, x):
+        return self.backbone(x)
+
+    def fpn_predict(self, x: dict, logits: dict):
+        """
+        x: [B, C, H, W] or [B, S, C]
+           [B, C, H, W] --> [B, H*W, C]
+        """
+        for name in x:
+            ### predict on each features point
+            if len(x[name].size()) == 4:
+                B, C, H, W = x[name].size()
+                logit = x[name].view(B, C, H * W)
+            elif len(x[name].size()) == 3:
+                logit = x[name].transpose(1, 2).contiguous()
+            logits[name] = getattr(self, "fpn_classifier_" + name)(logit)
+            logits[name] = logits[name].transpose(1, 2).contiguous()  # transpose
+
+    def infer(self, x: torch.Tensor):
+        logits = {}
+        x = self.forward_backbone(x)
+
+        if self.fpn:
+            x = self.fpn(x)
+            self.fpn_predict(x, logits)
+
+        if self.selector:
+            selects = self.selector(x, logits)
+
+        if self.combiner:
+            comb_outs = self.combiner(selects)
+            logits['comb_outs'] = comb_outs
+            return logits
+
+        if self.selector or self.fpn:
+            return logits
+
+        ### original backbone (only predict final selected layer)
+        for name in x:
+            hs = x[name]
+
+        if len(hs.size()) == 4:
+            hs = F.adaptive_avg_pool2d(hs, (1, 1))
+            hs = hs.flatten(1)
+        else:
+            hs = hs.mean(1)
+        out = self.classifier(hs)
+        logits['ori_out'] = out
+
+        return logits
+
+    def forward(self, x, target=None):
+        logits = self.infer(x)
+
+        if not self.training:
+            return logits
+        else:
+            losses = list()
+            batch_size = x.shape[0]
+            device = x.device
+            for name in logits:
+                if "select_" in name:
+                    if not self.use_selection:
+                        raise ValueError("Selector not use here.")
+                    if self.lambda_s != 0:
+                        S = logits[name].size(1)
+                        logit =
logits[name].view(-1, self.num_classes).contiguous() + loss_s = nn.CrossEntropyLoss()(logit.float(), target.unsqueeze(1).repeat(1, S).flatten(0)) + losses.append(LossItem(name="loss_s", value=loss_s, weight=self.lambda_s)) + + elif "drop_" in name: + if not self.use_selection: + raise ValueError("Selector not use here.") + + if self.lambda_n != 0: + S = logits[name].size(1) + logit = logits[name].view(-1, self.num_classes).contiguous() + n_preds = nn.Tanh()(logit) + labels_0 = (torch.zeros([batch_size * S, self.num_classes]) - 1).to(device) + loss_n = nn.MSELoss()(n_preds.float(), labels_0) + losses.append(LossItem(name="loss_n", value=loss_n, weight=self.lambda_n)) + + + elif "layer" in name: + if not self.use_fpn: + raise ValueError("FPN not use here.") + if self.lambda_b != 0: + ### here using 'layer1'~'layer4' is default setting, you can change to your own + loss_b = nn.CrossEntropyLoss()(logits[name].mean(1).float(), target) + losses.append(LossItem(name="loss_b", value=loss_b, weight=self.lambda_b)) + + elif "comb_outs" in name: + if not self.use_combiner: + raise ValueError("Combiner not use here.") + + if self.lambda_c != 0: + loss_c = nn.CrossEntropyLoss()(logits[name].float(), target) + losses.append(LossItem(name="loss_c", value=loss_c, weight=self.lambda_c)) + + elif "ori_out" in name: + loss_ori = F.cross_entropy(logits[name].float(), target) + losses.append(LossItem(name="loss_ori", value=loss_ori, weight=1.0)) + + return logits, losses \ No newline at end of file diff --git a/fgvclib/models/sotas/transFG.py b/fgvclib/models/sotas/transFG.py new file mode 100644 index 0000000..7e284c7 --- /dev/null +++ b/fgvclib/models/sotas/transFG.py @@ -0,0 +1,130 @@ +# coding=utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch.nn.functional as F +import numpy as np +import torch +import torch.nn as nn +import os.path as op +from scipy import ndimage +from torch.nn import Linear +from yacs.config import CfgNode + +from fgvclib.criterions import LossItem +from fgvclib.models.sotas.sota import FGVCSOTA +from fgvclib.models.sotas import fgvcmodel + + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +@fgvcmodel("TransFG") +class TransFG(FGVCSOTA): + def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): + super().__init__(cfg, backbone, encoder, necks, heads, criterions) + + self.num_classes = cfg.CLASS_NUM + self.smoothing_value = self.args["smoothing_value"] + self.zero_head = self.args["zero_head"] + self.classifier = self.args["classifier"] + self.part_head = Linear(self.args["part_head_in_channels"], self.num_classes) + self.pretrained_weight = self.args["pretrained_weight"] + + if op.exists(self.pretrained_weight): + print(f"Loading pretraining weight in {self.pretrained_weight}.") + if self.pretrained_weight.endswith('.npz'): + self.load_from(np.load(self.pretrained_weight)) + else: + self.load_state_dict(torch.load(self.pretrained_weight)) + + + def forward(self, x, labels=None): + part_tokens = self.encoder(x) + part_logits = self.part_head(part_tokens[:, 0]) + if labels is not None: + losses = list() + if self.smoothing_value == 0: + part_loss = self.criterions['cross_entropy_loss']['fn'](part_logits.view(-1, self.num_classes), + labels.view(-1)) + else: + part_loss = 
self.criterions['nll_loss_labelsmoothing']['fn'](part_logits.view(-1, self.num_classes), + labels.view(-1)) + contrast_loss = con_loss(part_tokens[:, 0], labels.view(-1)) + losses.append(LossItem(name='contrast_loss', value=contrast_loss, weight=1.0)) + losses.append(LossItem(name='part_loss', value=part_loss, weight=1.0)) + return part_logits, losses + else: + return part_logits + + def load_from(self, weights): + with torch.no_grad(): + self.encoder.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) + self.encoder.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) + self.encoder.embeddings.cls_token.copy_(np2th(weights["cls"])) + self.encoder.encoder.part_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) + self.encoder.encoder.part_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) + + posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) + posemb_new = self.encoder.embeddings.position_embeddings + if posemb.size() == posemb_new.size(): + self.encoder.embeddings.position_embeddings.copy_(posemb) + else: + + ntok_new = posemb_new.size(1) + + if self.classifier == "token": + posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] + ntok_new -= 1 + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + + gs_old = int(np.sqrt(len(posemb_grid))) + gs_new = int(np.sqrt(ntok_new)) + print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) + posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) + + zoom = (gs_new / gs_old, gs_new / gs_old, 1) + posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) + posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) + posemb = np.concatenate([posemb_tok, posemb_grid], axis=1) + self.encoder.embeddings.position_embeddings.copy_(np2th(posemb)) + + for bname, block in self.encoder.encoder.named_children(): + if bname.startswith('part') == False: + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=uname) + + if self.encoder.embeddings.hybrid: + self.encoder.embeddings.hybrid_model.root.conv.weight.copy_( + np2th(weights["conv_root/kernel"], conv=True)) + gn_weight = np2th(weights["gn_root/scale"]).view(-1) + gn_bias = np2th(weights["gn_root/bias"]).view(-1) + self.encoder.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) + self.encoder.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) + + for bname, block in self.encoder.embeddings.hybrid_model.body.named_children(): + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=bname, n_unit=uname) + + + + +def con_loss(features, labels): + B, _ = features.shape + features = F.normalize(features) + cos_matrix = features.mm(features.t()) + pos_label_matrix = torch.stack([labels == labels[i] for i in range(B)]).float() + neg_label_matrix = 1 - pos_label_matrix + pos_cos_matrix = 1 - cos_matrix + neg_cos_matrix = cos_matrix - 0.4 + neg_cos_matrix[neg_cos_matrix < 0] = 0 + loss = (pos_cos_matrix * pos_label_matrix).sum() + (neg_cos_matrix * neg_label_matrix).sum() + loss /= (B * B) + return loss diff --git a/fgvclib/optimizers/__init__.py b/fgvclib/optimizers/__init__.py new file mode 100644 index 0000000..7fa56e2 --- /dev/null +++ b/fgvclib/optimizers/__init__.py @@ -0,0 +1,38 @@ +import os +import importlib + + +__OPTMIZER_DICT__ = {} + + +def get_optimizer(name): + r"""Return the metric with the given name. + + Args: + name (str): + The name of metric. + + Return: + The metric contructor method. 
+ """ + + return __OPTMIZER_DICT__[name] + +def get_optimizer(name): + return __OPTMIZER_DICT__[name] + +def optimizer(name): + + def register_function_fn(cls): + if name in __OPTMIZER_DICT__: + raise ValueError("Name %s already registered!" % name) + __OPTMIZER_DICT__[name] = cls + return cls + + return register_function_fn + +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith('.py') and not file.startswith('_'): + module_name = file[:file.find('.py')] + module = importlib.import_module('fgvclib.optimizers.' + module_name) + diff --git a/fgvclib/optimizers/adam.py b/fgvclib/optimizers/adam.py new file mode 100644 index 0000000..f998c18 --- /dev/null +++ b/fgvclib/optimizers/adam.py @@ -0,0 +1,9 @@ +from torch.optim import Adam + + +from .import optimizer + +@optimizer("Adam") +def adam(params, lr, cfg): + return Adam(params=params, lr=lr, **cfg) + diff --git a/fgvclib/optimizers/adamw.py b/fgvclib/optimizers/adamw.py new file mode 100644 index 0000000..8e6b069 --- /dev/null +++ b/fgvclib/optimizers/adamw.py @@ -0,0 +1,8 @@ +from torch.optim import AdamW + + +from .import optimizer + +@optimizer("AdamW") +def adamw(params, lr, cfg): + return AdamW(params=params, lr=lr, **cfg) \ No newline at end of file diff --git a/fgvclib/optimizers/sgd.py b/fgvclib/optimizers/sgd.py new file mode 100644 index 0000000..6d31d8d --- /dev/null +++ b/fgvclib/optimizers/sgd.py @@ -0,0 +1,8 @@ +from torch.optim import SGD + + +from .import optimizer + +@optimizer("SGD") +def sgd(params, lr, cfg): + return SGD(params=params, lr=lr, **cfg) \ No newline at end of file diff --git a/fgvclib/transforms/base_transforms.py b/fgvclib/transforms/base_transforms.py index 06317bc..6981617 100644 --- a/fgvclib/transforms/base_transforms.py +++ b/fgvclib/transforms/base_transforms.py @@ -30,3 +30,11 @@ def normalize(cfg: dict): @transform("color_jitter") def color_jitter(cfg: dict): return transforms.ColorJitter(brightness=cfg['brightness'], saturation=cfg['saturation']) + +@transform('randomApply_gaussianBlur') +def randomApply_gaussianBlur(cfg: dict): + return transforms.RandomApply([transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 5))], p=cfg['prob']) + +@transform('randomAdjust_sharpness') +def randomAdjust_sharpness(cfg:dict): + return transforms.RandomAdjustSharpness(sharpness_factor=cfg['sharpness_factor'], p=cfg['prob']) diff --git a/fgvclib/utils/evaluate_function/__init__.py b/fgvclib/utils/evaluate_function/__init__.py new file mode 100644 index 0000000..c110e79 --- /dev/null +++ b/fgvclib/utils/evaluate_function/__init__.py @@ -0,0 +1,32 @@ +import os +import importlib + +__EVAL_FN_DICT__ = {} + + +def get_evaluate_function(name): + r"""Return the evaluate function with the given name. + Args: + name (str): + The name of evaluate function. + Return: + (function): The evaluate function. + """ + + return __EVAL_FN_DICT__[name] + +def evaluate_function(name): + + def register_function_fn(cls): + if name in __EVAL_FN_DICT__: + raise ValueError("Name %s already registered!" % name) + __EVAL_FN_DICT__[name] = cls + return cls + + return register_function_fn + +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith('.py') and not file.startswith('_'): + module_name = file[:file.find('.py')] + module = importlib.import_module('fgvclib.utils.evaluate_function.' 
+ module_name) + \ No newline at end of file diff --git a/fgvclib/utils/evaluate_function/evaluate_model.py b/fgvclib/utils/evaluate_function/evaluate_model.py new file mode 100644 index 0000000..9dc6fb5 --- /dev/null +++ b/fgvclib/utils/evaluate_function/evaluate_model.py @@ -0,0 +1,128 @@ +# Copyright (c) 2022-present, BUPT-PRIS. + +""" + This file provides a api for evaluating FGVC algorithms. +""" + + +import torch +import torch.nn as nn +from torch.autograd import Variable +import typing as t + +from fgvclib.metrics.metrics import NamedMetric +from . import evaluate_function + +@evaluate_function("general_evaluate") +def general_evaluate(model:nn.Module, p_bar:t.Iterable, metrics:t.List[NamedMetric], use_cuda:bool=True) -> t.Dict: + r"""Evaluate the FGVC model. + + Args: + model (nn.Module): + The FGVC model. + p_bar (iterable): + A iterator provide test data. + metrics (List[NamedMetric]): + List of metrics. + use_cuda (boolean, optional): + Whether to use gpu. + + Returns: + dict: The result dict. + """ + + model.eval() + results = dict() + + with torch.no_grad(): + for _, (inputs, targets) in enumerate(p_bar): + if use_cuda: + inputs, targets = inputs.cuda(), targets.cuda() + inputs, targets = Variable(inputs), Variable(targets) + logits = model(inputs) + for metric in metrics: + _ = metric.update(logits, targets) + + for metric in metrics: + result = metric.compute() + results.update({ + metric.name: round(result.item(), 3) + }) + + return results + +@evaluate_function("swin_transformer_evaluate") +def swin_transformer_evaluate(model:nn.Module, p_bar:t.Iterable, metrics:t.List[NamedMetric], use_cuda:bool=True) -> t.Dict: + r"""Evaluate the FGVC model. + + Args: + model (nn.Module): + The FGVC model. + p_bar (iterable): + A iterator provide test data. + metrics (List[NamedMetric]): + List of metrics. + use_cuda (boolean, optional): + Whether to use gpu. + + Returns: + dict: The result dict. + """ + + model.eval() + results = dict() + + with torch.no_grad(): + for _, (inputs, targets) in enumerate(p_bar): + if use_cuda: + inputs, targets = inputs.cuda(), targets.cuda() + inputs, targets = Variable(inputs), Variable(targets) + logits = model(inputs) + for metric in metrics: + _ = metric.update(logits["comb_outs"], targets) + + for metric in metrics: + result = metric.compute() + results.update({ + metric.name: round(result.item(), 3) + }) + + return results + +@evaluate_function("wsdan_cal_evaluate") +def wsdan_cal_evaluate(model:nn.Module, p_bar:t.Iterable, metrics:t.List[NamedMetric], use_cuda:bool=True) -> t.Dict: + r"""Evaluate the FGVC model. + + Args: + model (nn.Module): + The FGVC model. + p_bar (iterable): + A iterator provide test data. + metrics (List[NamedMetric]): + List of metrics. + use_cuda (boolean, optional): + Whether to use gpu. + + Returns: + dict: The result dict. 
+ """ + + model.eval() + results = dict() + + with torch.no_grad(): + for _, (inputs, targets) in enumerate(p_bar): + if use_cuda: + inputs, targets = inputs.cuda(), targets.cuda() + inputs, targets = Variable(inputs), Variable(targets) + logits = model(inputs) + for metric in metrics: + _ = metric.update(logits["comb_outs"], targets) + + for metric in metrics: + result = metric.compute() + results.update({ + metric.name: round(result.item(), 3) + }) + + return results diff --git a/fgvclib/utils/lr_schedules/__init__.py b/fgvclib/utils/lr_schedules/__init__.py index 6ed8cd4..cc8eb85 100644 --- a/fgvclib/utils/lr_schedules/__init__.py +++ b/fgvclib/utils/lr_schedules/__init__.py @@ -1,5 +1,6 @@ import os import importlib +from torch.optim.lr_scheduler import _LRScheduler from fgvclib.utils.lr_schedules.lr_schedule import LRSchedule @@ -8,12 +9,12 @@ def get_lr_schedule(name) -> LRSchedule: - r"""Return the dataset with the given name. + r"""Return the learning rate schedule with the given name. Args: name (str): - The name of dataset. + The name of learning rate schedule. Return: - (FGVCDataset): The dataset contructor method. + (LRSchedule): The learning rate schedule contructor. """ return __LR_SCHEDULE_DICT__[name] @@ -23,8 +24,8 @@ def lr_schedule(name): def register_function_fn(cls): if name in __LR_SCHEDULE_DICT__: raise ValueError("Name %s already registered!" % name) - if not issubclass(cls, LRSchedule): - raise ValueError("Class %s is not a subclass of %s" % (cls, LRSchedule)) + # if not issubclass(cls, LRSchedule) and not issubclass(cls, _LRScheduler): + # raise ValueError("Class %s is not a subclass of %s or %s" % (cls, LRSchedule, _LRScheduler)) __LR_SCHEDULE_DICT__[name] = cls return cls diff --git a/fgvclib/utils/lr_schedules/adjusting_schedule.py b/fgvclib/utils/lr_schedules/adjusting_schedule.py index a624574..523bce3 100644 --- a/fgvclib/utils/lr_schedules/adjusting_schedule.py +++ b/fgvclib/utils/lr_schedules/adjusting_schedule.py @@ -1,18 +1,24 @@ from .lr_schedule import LRSchedule from . import lr_schedule -@lr_schedule("adjusting_schedule") + class Adjusting_Schedule(LRSchedule): - def __init__(self, cfg) -> None: - super().__init__(cfg) - self.base_rate = cfg["base_rate"] - self.base_duration = cfg["base_duration"] - self.base_lr = cfg["base_lr"] - self.update_level = 'batch_update' + def __init__(self, optimizer, base_rate, base_duration, base_lr) -> None: + super().__init__(optimizer) + + self.base_rate = base_rate + self.base_duration = base_duration + self.base_lr = base_lr + self.update_level = "batch" - def step(self, optimizer, batch_idx, current_epoch, batch_size, **kwargs): - iter = float(batch_idx) / batch_size + def step(self, batch_idx, current_epoch, total_batch, **kwargs): + iter = float(batch_idx) / total_batch lr = self.base_lr * pow(self.base_rate, (current_epoch + iter) / self.base_duration) - for pg in optimizer.param_groups: + for pg in self.optimizer.param_groups: pg['lr'] = lr + + +@lr_schedule("adjusting_schedule") +def adjusting_schedule(optimizer, batch_num_per_epoch, cfg:dict): + return Adjusting_Schedule(optimizer, **cfg) diff --git a/fgvclib/utils/lr_schedules/constant_schedule.py b/fgvclib/utils/lr_schedules/constant_schedule.py new file mode 100644 index 0000000..6c2fdea --- /dev/null +++ b/fgvclib/utils/lr_schedules/constant_schedule.py @@ -0,0 +1,8 @@ +from torch.optim.lr_scheduler import LambdaLR + + +class ConstantSchedule(LambdaLR): + """ Constant learning rate schedule. 
+ """ + def __init__(self, optimizer, total_epoch=-1): + super(ConstantSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=total_epoch) \ No newline at end of file diff --git a/fgvclib/utils/lr_schedules/cosine_anneal_schedule.py b/fgvclib/utils/lr_schedules/cosine_anneal_schedule.py index da135ee..db4af36 100644 --- a/fgvclib/utils/lr_schedules/cosine_anneal_schedule.py +++ b/fgvclib/utils/lr_schedules/cosine_anneal_schedule.py @@ -3,18 +3,22 @@ from .lr_schedule import LRSchedule from . import lr_schedule -@lr_schedule("cosine_anneal_schedule") + class CosineAnnealSchedule(LRSchedule): - def __init__(self, cfg) -> None: - super().__init__(cfg) - self.update_level = 'epoch_update' + def __init__(self, optimizer, **kwargs) -> None: + super().__init__(optimizer) + self.update_level = 'epoch' - def step(self, optimizer, current_epoch, total_epoch, **kwargs): + def step(self, current_epoch, total_epoch, **kwargs): cos_inner = np.pi * (current_epoch % (total_epoch)) cos_inner /= (total_epoch) cos_out = np.cos(cos_inner) + 1 - for pg in optimizer.param_groups: + for pg in self.optimizer.param_groups: current_lr = pg['lr'] pg['lr'] = float(current_lr / 2 * cos_out) + +@lr_schedule("cosine_anneal_schedule") +def cosine_anneal_schedule(optimizer, batch_num_per_epoch, cfg:dict): + return CosineAnnealSchedule(optimizer, **cfg) diff --git a/fgvclib/utils/lr_schedules/cosine_decay_schedule.py b/fgvclib/utils/lr_schedules/cosine_decay_schedule.py new file mode 100644 index 0000000..c51c55e --- /dev/null +++ b/fgvclib/utils/lr_schedules/cosine_decay_schedule.py @@ -0,0 +1,38 @@ +import math +import numpy as np + +from . import lr_schedule +from .lr_schedule import LRSchedule + +class WarmupCosineDecaySchedule(LRSchedule): + + def __init__(self, optimizer, max_epochs, warmup_steps, max_lr, batch_num_per_epoch, decay_type) -> None: + super().__init__(optimizer) + + total_batchs = max_epochs * batch_num_per_epoch + iters = np.arange(total_batchs - warmup_steps) + + self.update_level = 'batch' + + if decay_type == 1: + schedule = np.array([1e-12 + 0.5 * (max_lr - 1e-12) * (1 + math.cos(math.pi * t / total_batchs)) for t in iters]) + elif decay_type == 2: + schedule = max_lr * np.array([math.cos(7 * math.pi * t / (16 * total_batchs)) for t in iters]) + else: + raise ValueError("Not support this decay type") + + if warmup_steps > 0: + warmup_lr_schedule = np.linspace(1e-9, max_lr, warmup_steps) + schedule = np.concatenate((warmup_lr_schedule, schedule)) + + self.schedule = schedule + + def step(self, iteration): + + for pg in self.optimizer.param_groups: + pg['lr'] = self.schedule[iteration] + + +@lr_schedule("warmup_cosine_decay_schedule") +def warmup_cosine_decay_schedule(optimizer, batch_num_per_epoch, cfg:dict): + return WarmupCosineDecaySchedule(optimizer=optimizer, batch_num_per_epoch=batch_num_per_epoch, **cfg) diff --git a/fgvclib/utils/lr_schedules/lr_schedule.py b/fgvclib/utils/lr_schedules/lr_schedule.py index d995b8e..f31557b 100644 --- a/fgvclib/utils/lr_schedules/lr_schedule.py +++ b/fgvclib/utils/lr_schedules/lr_schedule.py @@ -1,7 +1,7 @@ class LRSchedule: - def __init__(self, cfg) -> None: - self.cfg = cfg + def __init__(self, optimizer) -> None: + self.optimizer = optimizer - def step(**kwargs): + def step(self): raise NotImplementedError("Eacbh subclass of LRSchedule should implemented the step method.") diff --git a/fgvclib/utils/lr_schedules/warmup_cosine_schedule.py b/fgvclib/utils/lr_schedules/warmup_cosine_schedule.py new file mode 100644 index 0000000..71fd187 --- /dev/null 
+++ b/fgvclib/utils/lr_schedules/warmup_cosine_schedule.py @@ -0,0 +1,28 @@ +from torch.optim.lr_scheduler import LambdaLR + +from . import lr_schedule + + + +class WarmupLinearSchedule(LambdaLR): + + + def __init__(self, optimizer, warmup_steps, total_steps, last_epoch=-1): + self.warmup_steps = warmup_steps + self.total_steps = total_steps + super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) + self.update_level = "batch" + + def lr_lambda(self, step): + if step < self.warmup_steps: + return float(step) / float(max(1, self.warmup_steps)) + return max(0.0, float(self.total_steps - step) / float(max(1.0, self.total_steps - self.warmup_steps))) + + def step(self, **kwargs): + return super().step() + + +@lr_schedule("warmup_linear_schedule") +def warmup_linear_schedule(optimizer, batch_num_per_epoch, cfg:dict): + return WarmupLinearSchedule(optimizer=optimizer, **cfg) + diff --git a/fgvclib/utils/update_function/__init__.py b/fgvclib/utils/update_function/__init__.py new file mode 100644 index 0000000..b5c59c6 --- /dev/null +++ b/fgvclib/utils/update_function/__init__.py @@ -0,0 +1,32 @@ +import os +import importlib + + +__UPDATE_FN_DICT__ = {} + + +def get_update_function(name): + r"""Return the update model function with the given name. + Args: + name (str): + The name of update function. + Return: + (function): The update function. + """ + + return __UPDATE_FN_DICT__[name] + +def update_function(name): + + def register_function_fn(cls): + if name in __UPDATE_FN_DICT__: + raise ValueError("Name %s already registered!" % name) + __UPDATE_FN_DICT__[name] = cls + return cls + + return register_function_fn + +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith('.py') and not file.startswith('_'): + module_name = file[:file.find('.py')] + module = importlib.import_module('fgvclib.utils.update_function.' + module_name) \ No newline at end of file diff --git a/fgvclib/apis/update_model.py b/fgvclib/utils/update_function/general_update.py similarity index 68% rename from fgvclib/apis/update_model.py rename to fgvclib/utils/update_function/general_update.py index b963144..b6f695a 100644 --- a/fgvclib/apis/update_model.py +++ b/fgvclib/utils/update_function/general_update.py @@ -1,16 +1,20 @@ from typing import Iterable -import torch.nn as nn from torch.optim import Optimizer +import torch.nn as nn +from . import update_function from fgvclib.utils.update_strategy import get_update_strategy from fgvclib.utils.logger import Logger from fgvclib.utils.lr_schedules import LRSchedule -def update_model( + + +@update_function("general_update") +def general_update( model: nn.Module, optimizer: Optimizer, pbar:Iterable, lr_schedule:LRSchedule=None, strategy:str="general_updating", use_cuda:bool=True, logger:Logger=None, - epoch:int=None, total_epoch:int=None, **kwargs + epoch:int=None, total_epoch:int=None, amp:bool=False, **kwargs ): r"""Update the FGVC model and record losses. @@ -25,19 +29,23 @@ def update_model( epoch (int): The current epoch number. total_epoch (int): The total epoch number. """ + model.train() mean_loss = 0. 
for batch_idx, train_data in enumerate(pbar): - losses_info = get_update_strategy(strategy)(model, train_data, optimizer, use_cuda) + losses_info = get_update_strategy(strategy)(model=model, train_data=train_data, optimizer=optimizer, use_cuda=use_cuda, amp=amp) mean_loss = (mean_loss * batch_idx + losses_info['iter_loss']) / (batch_idx + 1) losses_info.update({"mean_loss": mean_loss}) logger(losses_info, step=batch_idx) pbar.set_postfix(losses_info) - if lr_schedule.update_level == 'batch_update': - lr_schedule.step(optimizer=optimizer, batch_idx=batch_idx, batch_size=len(train_data), current_epoch=epoch, total_epoch=total_epoch) + if lr_schedule.update_level == 'batch': + lr_schedule.step(batch_idx=batch_idx, total_batch=len(pbar), current_epoch=epoch, total_epoch=total_epoch) - if lr_schedule.update_level == 'epoch_update': - lr_schedule.step(optimizer=optimizer, current_epoch=epoch, total_epoch=total_epoch) + if lr_schedule.update_level == 'epoch': + lr_schedule.step(current_epoch=epoch, total_epoch=total_epoch) + + + diff --git a/fgvclib/utils/update_function/update_swin_transformer.py b/fgvclib/utils/update_function/update_swin_transformer.py new file mode 100644 index 0000000..1534b0c --- /dev/null +++ b/fgvclib/utils/update_function/update_swin_transformer.py @@ -0,0 +1,74 @@ +from typing import Iterable +from torch.autograd import Variable +from torch.optim import Optimizer +from torch.cuda.amp import autocast, GradScaler +import torch.nn as nn +import torch +import torch.functional as F + +from . import update_function +from fgvclib.utils.logger import Logger +from fgvclib.criterions import compute_loss_value, detach_loss_value + + + +@update_function("update_swin_transformer") +def update_swin_transformer(model: nn.Module, optimizer: Optimizer, pbar:Iterable, lr_schedule=None, + strategy:str="update_swint", use_cuda:bool=True, logger:Logger=None, + epoch:int=None, total_epoch:int=None, amp:bool=False, use_selection=False, cfg=None, **kwargs, +): + scaler = GradScaler() + + optimizer.zero_grad() + + for batch_id, (inputs, targets) in enumerate(pbar): + model.train() + if lr_schedule.update_level == "batch": + iteration = epoch * len(pbar) + batch_id + lr_schedule.step(iteration) + + if lr_schedule.update_level == "epoch": + lr_schedule.step() + + if use_cuda: + inputs, targets = inputs.cuda(), targets.cuda() + inputs, targets = Variable(inputs), Variable(targets) + + with autocast(): + """ + [Model Return] + FPN + Selector + Combiner --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depend on your setting) + 'preds_0', 'preds_1', 'comb_outs' + FPN + Selector --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depend on your setting) + 'preds_0', 'preds_1' + FPN --> return 'layer1', 'layer2', 'layer3', 'layer4' (depend on your setting) + ~ --> return 'ori_out' + + [Retuen Tensor] + 'preds_0': logit has not been selected by Selector. + 'preds_1': logit has been selected by Selector. + 'comb_outs': The prediction of combiner. 
+ """ + out, losses = model(inputs, targets) + total_loss = compute_loss_value(losses) + if hasattr(model, 'update_freq'): + update_freq = model.update_freq + else: + update_freq = model.module.update_freq + total_loss /= update_freq + + if amp: + scaler.scale(total_loss).backward() + else: + total_loss.backward() + + losses_info = detach_loss_value(losses) + losses_info.update({"iter_loss": total_loss.item()}) + pbar.set_postfix(losses_info) + if (batch_id + 1) % update_freq == 0: + if amp: + scaler.step(optimizer) + scaler.update() # next batch + else: + optimizer.step() + optimizer.zero_grad() diff --git a/fgvclib/utils/update_strategy/__init__.py b/fgvclib/utils/update_strategy/__init__.py index 66fc984..8d2ce40 100644 --- a/fgvclib/utils/update_strategy/__init__.py +++ b/fgvclib/utils/update_strategy/__init__.py @@ -1,22 +1,32 @@ -from .progressive_updating_with_jigsaw import progressive_updating_with_jigsaw -from .progressive_updating_consistency_constraint import progressive_updating_consistency_constraint -from .general_updating import general_updating +import os +import importlib -__all__ = [ - 'progressive_updating_with_jigsaw', 'progressive_updating_consistency_constraint', 'general_updating' -] -def get_update_strategy(strategy_name): - r"""Return the learning rate schedule with the given name. +__UPDATE_ST_DICT__ = {} + +def get_update_strategy(name): + r"""Return the update strategy with the given name. Args: - strategy_name (str): - The name of the update strategy. - + name (str): + The name of update strategy. Return: - The update strategy contructor method. + (function): The update strategy function. """ + + return __UPDATE_ST_DICT__[name] + +def update_strategy(name): + + def register_function_fn(cls): + if name in __UPDATE_ST_DICT__: + raise ValueError("Name %s already registered!" % name) + __UPDATE_ST_DICT__[name] = cls + return cls + + return register_function_fn - if strategy_name not in globals(): - raise NotImplementedError(f"Strategy not found: {strategy_name}\nAvailable strategy: {__all__}") - return globals()[strategy_name] +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith('.py') and not file.startswith('_'): + module_name = file[:file.find('.py')] + module = importlib.import_module('fgvclib.utils.update_strategy.' + module_name) \ No newline at end of file diff --git a/fgvclib/utils/update_strategy/general_strategy.py b/fgvclib/utils/update_strategy/general_strategy.py new file mode 100644 index 0000000..30c54ef --- /dev/null +++ b/fgvclib/utils/update_strategy/general_strategy.py @@ -0,0 +1,35 @@ +from torch.autograd import Variable +import typing as t +from torch.cuda.amp import autocast, GradScaler + +from . 
import update_strategy +from fgvclib.criterions import compute_loss_value, detach_loss_value + + + +@update_strategy("general_strategy") +def general_strategy(model, train_data, optimizer, use_cuda=True, amp=False, **kwargs) -> t.Dict: + if amp: + scaler = GradScaler() + inputs, targets = train_data + if use_cuda: + inputs, targets = inputs.cuda(), targets.cuda() + inputs, targets = Variable(inputs), Variable(targets) + if amp: + with autocast(): + out, losses = model(inputs, targets) + total_loss = compute_loss_value(losses) + scaler.scale(total_loss).backward() + scaler.step(optimizer) + scaler.update() + else: + out, losses = model(inputs, targets) + total_loss = compute_loss_value(losses) + total_loss.backward() + optimizer.step() + + optimizer.zero_grad() + losses_info = detach_loss_value(losses) + losses_info.update({"iter_loss": total_loss.item()}) + + return losses_info \ No newline at end of file diff --git a/fgvclib/utils/update_strategy/progressive_updating_consistency_constraint.py b/fgvclib/utils/update_strategy/progressive_consistency_constraint.py similarity index 93% rename from fgvclib/utils/update_strategy/progressive_updating_consistency_constraint.py rename to fgvclib/utils/update_strategy/progressive_consistency_constraint.py index b06f151..3ff53f4 100644 --- a/fgvclib/utils/update_strategy/progressive_updating_consistency_constraint.py +++ b/fgvclib/utils/update_strategy/progressive_consistency_constraint.py @@ -3,11 +3,11 @@ from torch import nn, Tensor from torch.autograd import Variable +from . import update_strategy from fgvclib.criterions import compute_loss_value, detach_loss_value -BLOCKS = [[8, 8, 0, 0], [4, 4, 4, 0], [2, 2, 2, 2]] -alpha = [0.01, 0.05, 0.1] +@update_strategy("progressive_updating_consistency_constraint") def progressive_updating_consistency_constraint(model:nn.Module, train_data:t.Tuple[Tensor, Tensor, Tensor], optimizer, use_cuda=True, **kwargs) -> t.Dict: inputs, positive_inputs, targets = train_data batch_size = inputs.size(0) diff --git a/fgvclib/utils/update_strategy/progressive_updating_with_jigsaw.py b/fgvclib/utils/update_strategy/progressive_updating_with_jigsaw.py index c5e13b6..8eb268b 100644 --- a/fgvclib/utils/update_strategy/progressive_updating_with_jigsaw.py +++ b/fgvclib/utils/update_strategy/progressive_updating_with_jigsaw.py @@ -1,8 +1,11 @@ import random from torch.autograd import Variable +from . 
import update_strategy from fgvclib.criterions import compute_loss_value, detach_loss_value + +@update_strategy("progressive_updating_with_jigsaw") def progressive_updating_with_jigsaw(model, train_data, optimizer, use_cuda=True, **kwargs): inputs, targets = train_data if use_cuda: diff --git a/fgvclib/utils/update_strategy/general_updating.py b/fgvclib/utils/update_strategy/vit_updating.py similarity index 60% rename from fgvclib/utils/update_strategy/general_updating.py rename to fgvclib/utils/update_strategy/vit_updating.py index be49923..4004c87 100644 --- a/fgvclib/utils/update_strategy/general_updating.py +++ b/fgvclib/utils/update_strategy/vit_updating.py @@ -1,20 +1,27 @@ -from torch.autograd import Variable -import typing as t - -from fgvclib.criterions import compute_loss_value, detach_loss_value - - -def general_updating(model, train_data, optimizer, use_cuda=True, **kwargs) -> t.Dict: - inputs, targets = train_data - if use_cuda: - inputs, targets = inputs.cuda(), targets.cuda() - inputs, targets = Variable(inputs), Variable(targets) - out, losses = model(inputs, targets) - total_loss = compute_loss_value(losses) - total_loss.backward() - optimizer.step() - optimizer.zero_grad() - losses_info = detach_loss_value(losses) - losses_info.update({"iter_loss": total_loss.item()}) - - return losses_info \ No newline at end of file +import numpy as np +import torch +from torch.autograd import Variable +import typing as t + +from ..update_strategy import update_strategy +from fgvclib.criterions import compute_loss_value, detach_loss_value + + +@update_strategy("vit_update_strategy") +def vit_update_strategy(model, train_data, optimizer, use_cuda=True, amp=False, **kwargs) -> t.Dict: + + inputs, targets = train_data + if use_cuda: + inputs, targets = inputs.cuda(), targets.cuda() + inputs, targets = Variable(inputs), Variable(targets) + out, losses = model(inputs, targets) + total_loss = compute_loss_value(losses) + total_loss = total_loss.mean() + total_loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + optimizer.zero_grad() + losses_info = detach_loss_value(losses) + losses_info.update({"iter_loss": total_loss.item()}) + + return losses_info diff --git a/main.py b/main.py index 765794c..c2c9720 100644 --- a/main.py +++ b/main.py @@ -7,7 +7,7 @@ from fgvclib.apis import * from fgvclib.configs import FGVCConfig -from fgvclib.utils import cosine_anneal_schedule, init_distributed_mode +from fgvclib.utils import init_distributed_mode def train(cfg: CfgNode): @@ -49,17 +49,16 @@ def train(cfg: CfgNode): ) model.to(device) + sampler_cfg = cfg.SAMPLER if cfg.DISTRIBUTED: model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True, device_ids=[cfg.GPU]) train_sampler = torch.utils.data.distributed.DistributedSampler(train_set) test_sampler = torch.utils.data.distributed.DistributedSampler(test_set) + sampler_cfg.TRAIN.IS_BATCH_SAMPLER = False else: - sampler_cfg = cfg.SAMPLER train_sampler = build_sampler(sampler_cfg.TRAIN)(train_set, **tltd(sampler_cfg.TRAIN.ARGS)) test_sampler = build_sampler(sampler_cfg.TEST)(test_set, **tltd(sampler_cfg.TEST.ARGS)) - - train_loader = build_dataloader( dataset=train_set, mode_cfg=cfg.DATASET.TRAIN, @@ -71,7 +70,6 @@ def train(cfg: CfgNode): dataset=test_set, mode_cfg=cfg.DATASET.TEST, sampler=test_sampler, - is_batch_sampler=sampler_cfg.TEST.IS_BATCH_SAMPLER ) optimizer = build_optimizer(cfg.OPTIMIZER, model) @@ -80,7 +78,11 @@ def train(cfg: CfgNode): metrics = build_metrics(cfg.METRICS) - 
lr_schedule = build_lr_schedule(cfg.LR_SCHEDULE) + lr_schedule = build_lr_schedule(optimizer, cfg.LR_SCHEDULE, train_loader) + + update_fn = build_update_function(cfg) + + evaluate_fn = build_evaluate_function(cfg) for epoch in range(cfg.START_EPOCH, cfg.EPOCH_NUM): if args.distributed: @@ -90,12 +92,10 @@ def train(cfg: CfgNode): logger(f'Epoch: {epoch + 1} / {cfg.EPOCH_NUM} Training') - - - update_model( + update_fn( model, optimizer, train_bar, strategy=cfg.UPDATE_STRATEGY, use_cuda=cfg.USE_CUDA, lr_schedule=lr_schedule, - logger=logger, epoch=epoch, total_epoch=cfg.EPOCH_NUM + logger=logger, epoch=epoch, total_epoch=cfg.EPOCH_NUM, amp=cfg.AMP ) test_bar = tqdm(test_loader) @@ -103,7 +103,7 @@ def train(cfg: CfgNode): logger(f'Epoch: {epoch + 1} / {cfg.EPOCH_NUM} Testing ') - acc = evaluate_model(model, test_bar, metrics=metrics, use_cuda=cfg.USE_CUDA) + acc = evaluate_fn(model, test_bar, metrics=metrics, use_cuda=cfg.USE_CUDA) print(acc) logger("Evalution Result:") logger(acc) @@ -112,7 +112,7 @@ def train(cfg: CfgNode): model_with_ddp = model.module else: model_with_ddp = model - save_model(cfg=cfg, model=model, logger=logger) + save_model(cfg=cfg, model=model_with_ddp, logger=logger) logger.finish() def predict(cfg: CfgNode): @@ -137,7 +137,8 @@ def predict(cfg: CfgNode): pbar = tqdm(loader) metrics = build_metrics(cfg.METRICS) - acc = evaluate_model(model, pbar, metrics=metrics, use_cuda=cfg.USE_CUDA) + evaluate_fn = build_evaluate_function(cfg) + acc = evaluate_fn(model, pbar, metrics=metrics, use_cuda=cfg.USE_CUDA) print(acc) @@ -157,6 +158,7 @@ def predict(cfg: CfgNode): config = FGVCConfig() config.load(args.config) cfg = config.cfg + set_seed(cfg.SEED) if args.distributed: cfg.DISTRIBUTED = args.distributed cfg.GPU = args.gpu
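The new fgvclib/optimizers, update_strategy, update_function and evaluate_function packages all share the same decorator-registry idiom: a module-level dict, a registering decorator, a lookup function, and an importlib loop that imports every sibling module so the decorators run at import time. A small usage sketch, assuming fgvclib is importable; the RMSprop entry is hypothetical and only shows how a new builder would be registered alongside the shipped SGD/Adam/AdamW ones.

from torch.optim import RMSprop

from fgvclib.optimizers import optimizer, get_optimizer

@optimizer("RMSprop")  # hypothetical extra builder, registered the same way as sgd.py / adam.py / adamw.py
def rmsprop(params, lr, cfg):
    # cfg carries the extra keyword arguments collected from the YAML ARGS list
    return RMSprop(params=params, lr=lr, **cfg)

# Lookup by the NAME string from the config, then call the builder:
# opt = get_optimizer("SGD")(model.parameters(), lr=0.001, cfg={"momentum": 0.9, "weight_decay": 0.0005})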
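The config convention used throughout this change (ARGS as a list of single-key mappings, plus an explicit base learning rate) is consumed through turn_list_to_dict (tltd) from fgvclib.configs.utils, whose body is not part of this diff. The helper below is an assumed, minimal reconstruction of that flattening step, and build_sgd_from_base_lr is a hypothetical illustration rather than the library's build_optimizer, which presumably also builds per-module parameter groups from the LR block.

from torch.optim import SGD

def turn_list_to_dict(args_list):
    # Flatten a YAML list of single-key mappings into one kwargs dict, e.g.
    # [{"momentum": 0.9}, {"weight_decay": 0.0005}] -> {"momentum": 0.9, "weight_decay": 0.0005}
    merged = {}
    for item in args_list:
        merged.update(item)
    return merged

def build_sgd_from_base_lr(model, optimizer_cfg):
    # optimizer_cfg mirrors the YAML block: NAME, ARGS (list of single-key dicts), LR with a "base" entry
    kwargs = turn_list_to_dict(optimizer_cfg["ARGS"])
    return SGD(model.parameters(), lr=optimizer_cfg["LR"]["base"], **kwargs)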
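The learning-rate schedules are refactored to own their optimizer and to declare an update_level, so the update function can decide whether to step them per batch or per epoch. A short usage sketch of the registered adjusting_schedule factory, assuming fgvclib and torch are importable; the argument values are illustrative.

import torch
from torch.optim import SGD

from fgvclib.utils.lr_schedules import get_lr_schedule

model = torch.nn.Linear(16, 4)
opt = SGD(model.parameters(), lr=0.001, momentum=0.9)

schedule = get_lr_schedule("adjusting_schedule")(
    opt,
    batch_num_per_epoch=100,  # unused by this particular schedule, but part of the factory signature
    cfg={"base_lr": 0.001, "base_duration": 2.0, "base_rate": 0.9},
)

# Batch-level schedules are stepped inside the batch loop, epoch-level ones after it:
if schedule.update_level == "batch":
    schedule.step(batch_idx=0, total_batch=100, current_epoch=0)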
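For the amp=True path, general_strategy and update_swin_transformer follow the standard torch.cuda.amp recipe: autocast around the forward pass and a GradScaler around backward and step. A minimal, self-contained sketch of that recipe with a scaler that persists across iterations; the helper name amp_step is illustrative, not a library function.

import torch
from torch.cuda.amp import autocast, GradScaler

def amp_step(model, inputs, targets, optimizer, scaler, criterion):
    optimizer.zero_grad()
    with autocast():
        logits = model(inputs)
        loss = criterion(logits, targets)
    scaler.scale(loss).backward()  # scaled backward to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales the gradients, then steps the optimizer
    scaler.update()
    return loss.detach()

# scaler = GradScaler()
# for inputs, targets in loader:
#     amp_step(model, inputs.cuda(), targets.cuda(), optimizer, scaler, torch.nn.CrossEntropyLoss())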
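update_swin_transformer additionally accumulates gradients: each batch's loss is divided by update_freq and the optimizer only steps every update_freq batches, so the effective batch size is the dataloader batch size times update_freq. A compact sketch of that accumulation loop; the function name and arguments are illustrative.

import torch

def train_with_accumulation(model, loader, optimizer, criterion, update_freq=2, device="cuda"):
    model.train()
    optimizer.zero_grad()
    for batch_id, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)
        # scale the loss so the accumulated gradient matches one large batch
        loss = criterion(model(inputs), targets) / update_freq
        loss.backward()
        if (batch_id + 1) % update_freq == 0:
            optimizer.step()
            optimizer.zero_grad()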
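The SwinTransformer model wires the new pieces as backbone -> FPN (encoder) -> WeaklySelector (necks) -> GCN combiner (heads), all exchanging dictionaries keyed by layer name. The sketch below only traces tensor shapes with random data: the per-stage token counts and channel widths are assumptions, while fpn_size, num_select and the 2720 total follow the Swin config shipped with this change.

import torch

B, fpn_size, num_classes = 2, 1536, 200
num_select = {"layer1": 2048, "layer2": 512, "layer3": 128, "layer4": 32}

# Token counts per stage are illustrative; channels follow a Swin-L style progression.
backbone_outs = {
    "layer1": torch.randn(B, 2304, 192),
    "layer2": torch.randn(B, 576, 384),
    "layer3": torch.randn(B, 144, 768),
    "layer4": torch.randn(B, 144, 1536),
}

# After the FPN every layer is projected to fpn_size channels (token count unchanged).
fpn_outs = {k: torch.randn(B, v.size(1), fpn_size) for k, v in backbone_outs.items()}

# The WeaklySelector keeps the num_select highest-confidence tokens per layer.
selections = {k: v[:, :num_select[k], :] for k, v in fpn_outs.items()}

combined = torch.cat(list(selections.values()), dim=1)
assert combined.size(1) == sum(num_select.values()) == 2720  # matches total_num_selects
logits = torch.randn(B, num_classes)  # shape of the GCN combiner's output ('comb_outs')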