From a7f93d6ceba4c9cbe80df98f37c77d3d01f8515d Mon Sep 17 00:00:00 2001 From: xin-ran-w Date: Mon, 16 Jan 2023 15:55:39 +0800 Subject: [PATCH 01/10] update support amp training --- configs/swin_transformer/swin_transformer.yml | 162 ++++++++++++++ configs/transfg/transFG_ViT_B_16.yml | 161 ++++++++++++++ fgvclib/apis/__init__.py | 2 +- fgvclib/apis/build.py | 6 +- fgvclib/apis/seed.py | 11 + fgvclib/apis/update_model.py | 44 ---- fgvclib/configs/config.py | 6 +- fgvclib/models/backbones/swin_t.py | 36 ++++ fgvclib/models/backbones/vit.py | 80 +++++++ fgvclib/models/encoders/fpn.py | 97 +++++++++ .../models/encoders/transformer_encoder.py | 203 ++++++++++++++++++ fgvclib/models/heads/gcn_combiner.py | 102 +++++++++ fgvclib/models/heads/mlp.py | 49 +++++ fgvclib/models/necks/weakly_selector.py | 83 +++++++ fgvclib/models/sotas/api_net.py | 4 +- fgvclib/models/sotas/cal.py | 12 +- fgvclib/models/sotas/mcl.py | 9 +- fgvclib/models/sotas/pmg.py | 5 +- fgvclib/models/sotas/resnet50.py | 5 +- fgvclib/models/sotas/sota.py | 11 +- fgvclib/models/sotas/swin_T.py | 104 +++++++++ fgvclib/models/sotas/transFG.py | 130 +++++++++++ .../utils/lr_schedules/constant_schedule.py | 10 + .../lr_schedules/cosine_decay_schedule.py | 33 +++ fgvclib/utils/update_function/__init__.py | 34 +++ .../utils/update_function/general_update.py | 189 ++++++++++++++++ .../utils/update_strategy/general_updating.py | 22 +- fgvclib/utils/update_strategy/vit_updating.py | 24 +++ main.py | 12 +- 29 files changed, 1569 insertions(+), 77 deletions(-) create mode 100644 configs/swin_transformer/swin_transformer.yml create mode 100644 configs/transfg/transFG_ViT_B_16.yml create mode 100644 fgvclib/apis/seed.py delete mode 100644 fgvclib/apis/update_model.py create mode 100644 fgvclib/models/backbones/swin_t.py create mode 100644 fgvclib/models/backbones/vit.py create mode 100644 fgvclib/models/encoders/fpn.py create mode 100644 fgvclib/models/encoders/transformer_encoder.py create mode 100644 fgvclib/models/heads/gcn_combiner.py create mode 100644 fgvclib/models/heads/mlp.py create mode 100644 fgvclib/models/necks/weakly_selector.py create mode 100644 fgvclib/models/sotas/swin_T.py create mode 100644 fgvclib/models/sotas/transFG.py create mode 100644 fgvclib/utils/lr_schedules/constant_schedule.py create mode 100644 fgvclib/utils/lr_schedules/cosine_decay_schedule.py create mode 100644 fgvclib/utils/update_function/__init__.py create mode 100644 fgvclib/utils/update_function/general_update.py create mode 100644 fgvclib/utils/update_strategy/vit_updating.py diff --git a/configs/swin_transformer/swin_transformer.yml b/configs/swin_transformer/swin_transformer.yml new file mode 100644 index 0000000..ac252c3 --- /dev/null +++ b/configs/swin_transformer/swin_transformer.yml @@ -0,0 +1,162 @@ +EXP_NAME: "SwinT" + +RESUME_WEIGHT: ~ + +WEIGHT: + NAME: "swinT.pth" + SAVE_DIR: "/mnt/sdb/data/wangxinran/weight/fgvclib" + +LOGGER: + NAME: "txt_logger" + +DATASET: + NAME: "CUB_200_2011" + ROOT: "/mnt/sdb/data/wangxinran/dataset" + TRAIN: + BATCH_SIZE: 8 + POSITIVE: 0 + PIN_MEMORY: True + SHUFFLE: True + NUM_WORKERS: 2 + TEST: + BATCH_SIZE: 8 + POSITIVE: 0 + PIN_MEMORY: False + SHUFFLE: True + NUM_WORKERS: 1 + +MODEL: + NAME: "Swin_T" + CLASS_NUM: 200 + ARGS: + - img_size: 384 + + CRITERIONS: + - name: "cross_entropy_loss" + args: [] + w: 1.0 + - name: "mean_square_error_loss" + args: [] + w: 1.0 + BACKBONE: + NAME: "swint" + ARGS: + - pretrained: True + ENCODER: + NAME: "fpn" + ARGS: + - fpn_size: 1536 + - proj_type: "Linear" + - upsample_type: 
"Conv" + NECKS: + NAME: "weakly_selector" + ARGS: + - num_classes: 200 + - num_selects: + - layer1: 2048 + - layer2: 512 + - layer3: 128 + - layer4: 32 + - fpn_size: 1536 + HEADS: + NAME: "GCN_combiner" + ARGS: + - num_selects: + - layer1: 2048 + - layer2: 512 + - layer3: 128 + - layer4: 32 + - total_num_selects: 2720 + - num_classes: 200 + - fpn_size: 1536 + +TRANSFORMS: + TRAIN: + - name: "resize" + size: + - 510 + - 510 + - name: "random_crop" + size: 384 + padding: 0 + - name: "random_horizontal_flip" + prob: 0.5 + - name: "randomApply_gaussianBlur" + prob: 0.1 + - name: "randomAdjust_sharpness" + sharpness_factor: 1.5 + prob: 0.1 + - name: "to_tensor" + - name: "normalize" + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + TEST: + - name: "resize" + size: + - 510 + - 510 + - name: "center_crop" + size: 384 + - name: "to_tensor" + - name: "normalize" + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + +OPTIMIZER: + NAME: "SGD" + MOMENTUM: 0.9 + LR: + backbone: 0.005 + encoder: 0.005 + necks: 0.005 + heads: 0.005 + WEIGHT_DECAY: 0.0005 + +ITERATION_NUM: ~ +WARMUP_BATCHS: 800 +MAX_LR: 0.005 +EPOCH_NUM: 50 +START_EPOCH: 0 +USE_AMP: True +UPDATE_STRATEGY: ~ + + + +# Validation details +PER_ITERATION: ~ +PER_EPOCH: ~ +METRICS: + - name: "accuracy(topk=1)" + metric: "accuracy" + top_k: 1 + threshold: ~ + - name: "accuracy(topk=5)" + metric: "accuracy" + top_k: 5 + threshold: ~ + - name: "recall(threshold=0.5)" + metric: "recall" + top_k: ~ + threshold: 0.5 + - name: "precision(threshold=0.5)" + metric: "precision" + top_k: ~ + threshold: 0.5 + +INTERPRETER: + NAME: "cam" + METHOD: "gradcam" + TARGET_LAYERS: + - "layer4" \ No newline at end of file diff --git a/configs/transfg/transFG_ViT_B_16.yml b/configs/transfg/transFG_ViT_B_16.yml new file mode 100644 index 0000000..bf277a5 --- /dev/null +++ b/configs/transfg/transFG_ViT_B_16.yml @@ -0,0 +1,161 @@ +EXP_NAME: "TransFG" + +RESUME_WEIGHT: ~ + +WEIGHT: + NAME: "transFG_ViT_B16.pth" + SAVE_DIR: "/mnt/sdb/data/wangxinran/weight/fgvclib" + +LOGGER: + NAME: "txt_logger" + +DATASET: + ROOT: "/mnt/sdb/data/wangxinran/dataset" + NAME: "CUB_200_2011" + TRAIN: + BATCH_SIZE: 8 + POSITIVE: 0 + PIN_MEMORY: True + SHUFFLE: True + NUM_WORKERS: 4 + TEST: + BATCH_SIZE: 8 + POSITIVE: 0 + PIN_MEMORY: False + SHUFFLE: False + NUM_WORKERS: 4 + +MODEL: + NAME: "TransFG" + CLASS_NUM: 200 + ARGS: + - smoothing_value: 0 + - zero_head: True + - classifier: 'token' + - part_head_in_channels: 768 + - pretrained_weight: "/mnt/sdb/data/wangxinran/weight/pretraining/transFG/ViT-B_16.npz" + + CRITERIONS: + - name: "nll_loss_labelsmoothing" + args: + - smoothing_value: 0.0 + w: 1.0 + - name: "cross_entropy_loss" + args: [] + w: 1.0 + BACKBONE: + NAME: "vit16" + ARGS: + - patch_size: 16 + - image_size: 448 + - split: 'non-overlap' + - slide_step: 12 + - hidden_size: 768 + - representation_size: None + - dropout_rate: 0.1 + ENCODER: + NAME: "transformer_encoder" + ARGS: + - num_layers: 12 + - img_size: 448 + - num_heads: 12 + - attention_dropout_rate: 0.0 + - hidden_size: 768 + - mlp_dim: 3072 + - dropout_rate: 0.1 + - patch_size: 16 + - split: 'non-overlap' + - slide_step: 12 + - mlp_dim: 3072 + NECKS: + NAME: ~ + HEADS: + NAME: "mlp" + ARGS: + - hidden_size: 768 + - mlp_dim: 3072 + - dropout_rate: 0.1 + +TRANSFORMS: + TRAIN: + - name: "resize" + size: + - 600 + - 600 + - name: "random_crop" + size: 448 + padding: 0 + - name: "random_horizontal_flip" + prob: 0.5 + - name: "to_tensor" + - name: "normalize" + mean: + - 0.485 
+ - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + TEST: + - name: "resize" + size: + - 600 + - 600 + - name: "center_crop" + size: 448 + - name: "to_tensor" + - name: "normalize" + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + +OPTIMIZER: + NAME: SGD + MOMENTUM: 0.9 + WEIGHT_DECAY: 0.0 + LR: + backbone: 0.08 + encoder: 0.08 + necks: 0.08 + heads: 0.08 + +WARMUP_STEPS: 500 +NUM_STEPS: 10000 +ITERATION_NUM: ~ +EPOCH_NUM: 100 +START_EPOCH: 0 +UPDATE_STRATEGY: "vit_updating" + + +# Validation details +PER_ITERATION: ~ +PER_EPOCH: ~ +METRICS: + - name: "accuracy(topk=1)" + metric: "accuracy" + top_k: 1 + threshold: ~ + - name: "accuracy(topk=5)" + metric: "accuracy" + top_k: 5 + threshold: ~ + - name: "recall(threshold=0.5)" + metric: "recall" + top_k: ~ + threshold: 0.5 + - name: "precision(threshold=0.5)" + metric: "precision" + top_k: ~ + threshold: 0.5 + +INTERPRETER: + NAME: "cam" + METHOD: "gradcam" + TARGET_LAYERS: + - "layer4" \ No newline at end of file diff --git a/fgvclib/apis/__init__.py b/fgvclib/apis/__init__.py index 3d13266..d716eee 100644 --- a/fgvclib/apis/__init__.py +++ b/fgvclib/apis/__init__.py @@ -1,4 +1,4 @@ from .build import * from .evaluate_model import evaluate_model -from .update_model import update_model from .save_model import save_model +from .seed import set_seed diff --git a/fgvclib/apis/build.py b/fgvclib/apis/build.py index f3985af..2c4febf 100644 --- a/fgvclib/apis/build.py +++ b/fgvclib/apis/build.py @@ -16,9 +16,10 @@ from fgvclib.configs.utils import turn_list_to_dict as tltd from fgvclib.criterions import get_criterion from fgvclib.datasets import get_dataset -from fgvclib.samplers import get_sampler from fgvclib.datasets.datasets import FGVCDataset +from fgvclib.samplers import get_sampler from fgvclib.metrics import get_metric +from fgvclib.metrics.metrics import NamedMetric from fgvclib.models.sotas import get_model from fgvclib.models.sotas.sota import FGVCSOTA from fgvclib.models.backbones import get_backbone @@ -29,7 +30,6 @@ from fgvclib.utils.logger import get_logger, Logger from fgvclib.utils.interpreter import get_interpreter, Interpreter from fgvclib.utils.lr_schedules import get_lr_schedule, LRSchedule -from fgvclib.metrics.metrics import NamedMetric def build_model(model_cfg: CfgNode) -> FGVCSOTA: @@ -64,7 +64,7 @@ def build_model(model_cfg: CfgNode) -> FGVCSOTA: criterions.update({item["name"]: {"fn": build_criterion(item), "w": item["w"]}}) model_builder = get_model(model_cfg.NAME) - model = model_builder(backbone=backbone, encoder=encoder, necks=necks, heads=heads, criterions=criterions) + model = model_builder(cfg=model_cfg, backbone=backbone, encoder=encoder, necks=necks, heads=heads, criterions=criterions) return model diff --git a/fgvclib/apis/seed.py b/fgvclib/apis/seed.py new file mode 100644 index 0000000..1b5fa96 --- /dev/null +++ b/fgvclib/apis/seed.py @@ -0,0 +1,11 @@ +import random +import numpy as np +import torch + + +def set_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + diff --git a/fgvclib/apis/update_model.py b/fgvclib/apis/update_model.py deleted file mode 100644 index b963144..0000000 --- a/fgvclib/apis/update_model.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import Iterable -import torch.nn as nn -from torch.optim import Optimizer - - -from fgvclib.utils.update_strategy import get_update_strategy -from fgvclib.utils.logger import Logger -from fgvclib.utils.lr_schedules import LRSchedule - -def 
update_model( - model: nn.Module, optimizer: Optimizer, pbar:Iterable, lr_schedule:LRSchedule=None, - strategy:str="general_updating", use_cuda:bool=True, logger:Logger=None, - epoch:int=None, total_epoch:int=None, **kwargs -): - r"""Update the FGVC model and record losses. - - Args: - model (nn.Module): The FGVC model. - optimizer (Optimizer): The Logger object. - pbar (Iterable): The iterable object provide training data. - lr_schedule (LRSchedule): The lr schedule updating class. - strategy (string): The update strategy. - use_cuda (boolean): Whether to use GPU to train the model. - logger (Logger): The Logger object. - epoch (int): The current epoch number. - total_epoch (int): The total epoch number. - """ - model.train() - mean_loss = 0. - for batch_idx, train_data in enumerate(pbar): - losses_info = get_update_strategy(strategy)(model, train_data, optimizer, use_cuda) - mean_loss = (mean_loss * batch_idx + losses_info['iter_loss']) / (batch_idx + 1) - losses_info.update({"mean_loss": mean_loss}) - logger(losses_info, step=batch_idx) - pbar.set_postfix(losses_info) - if lr_schedule.update_level == 'batch_update': - lr_schedule.step(optimizer=optimizer, batch_idx=batch_idx, batch_size=len(train_data), current_epoch=epoch, total_epoch=total_epoch) - - if lr_schedule.update_level == 'epoch_update': - lr_schedule.step(optimizer=optimizer, current_epoch=epoch, total_epoch=total_epoch) - - - - \ No newline at end of file diff --git a/fgvclib/configs/config.py b/fgvclib/configs/config.py index 17e6778..e17fd23 100644 --- a/fgvclib/configs/config.py +++ b/fgvclib/configs/config.py @@ -17,6 +17,9 @@ def __init__(self): # Name of experiment self.cfg.EXP_NAME = None + # Random Seed + self.cfg.SEED = 0 + # Resume last train self.cfg.RESUME_WEIGHT = None @@ -57,7 +60,6 @@ def __init__(self): self.cfg.DATASET.TEST.SHUFFLE = False self.cfg.DATASET.TEST.NUM_WORKERS = 0 - # sampler for dataloader self.cfg.SAMPLER = CN() @@ -72,6 +74,7 @@ def __init__(self): self.cfg.SAMPLER.TEST.ARGS = None self.cfg.SAMPLER.TEST.IS_BATCH_SAMPLER = False + # Model architecture self.cfg.MODEL = CN() self.cfg.MODEL.NAME = None @@ -125,6 +128,7 @@ def __init__(self): self.cfg.LR_SCHEDULE = CN() self.cfg.LR_SCHEDULE.NAME = "cosine_anneal_schedule" self.cfg.LR_SCHEDULE.ARGS = None + self.cfg.AMP = True # Validation self.cfg.PER_ITERATION = None diff --git a/fgvclib/models/backbones/swin_t.py b/fgvclib/models/backbones/swin_t.py new file mode 100644 index 0000000..0151722 --- /dev/null +++ b/fgvclib/models/backbones/swin_t.py @@ -0,0 +1,36 @@ +import timm +import torch +from fgvclib.models.backbones import backbone + + +def load_model_weights(model, model_path): + ### reference https://github.com/TACJu/TransFG + ### thanks a lot. 
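+    # Partial loading: only parameters whose name and shape match the
+    # checkpoint's 'state_dict' are copied; mismatches are reported and
+    # keep their initialized values.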
+    state = torch.load(model_path, map_location='cpu')
+    for key in model.state_dict():
+        if 'num_batches_tracked' in key:
+            continue
+        p = model.state_dict()[key]
+        if key in state['state_dict']:
+            ip = state['state_dict'][key]
+            if p.shape == ip.shape:
+                p.data.copy_(ip.data)  # copy matching parameter data
+            else:
+                print('could not load layer: {}, mismatched shape {} vs {}'.format(key, p.shape, ip.shape))
+        else:
+            print('could not load layer: {}, not in checkpoint'.format(key))
+    return model
+
+
+@backbone("swint")
+def swint(cfg):
+    backbone = timm.create_model('swin_large_patch4_window12_384_in22k', pretrained=cfg['pretrained'])
+    backbone.train()
+    return backbone
+
+
+# Note: this helper is not registered as a backbone; the registered "vit16"
+# backbone lives in vit.py.
+def vit16(cfg):
+    backbone = timm.create_model('vit_base_patch16_224_miil_in21k', pretrained=False)
+    backbone = load_model_weights(backbone, cfg['pretrained'])
+    backbone.train()
+    return backbone
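For reference, a minimal usage sketch of load_model_weights above. The checkpoint path is a placeholder, and the helper assumes a checkpoint saved as a dict with a 'state_dict' entry:

    import timm
    from fgvclib.models.backbones.swin_t import load_model_weights

    model = timm.create_model('vit_base_patch16_224_miil_in21k', pretrained=False)
    model = load_model_weights(model, '/path/to/vit_base_16_miil_21k.pth')  # placeholder path
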
diff --git a/fgvclib/models/backbones/vit.py b/fgvclib/models/backbones/vit.py
new file mode 100644
index 0000000..23c328c
--- /dev/null
+++ b/fgvclib/models/backbones/vit.py
@@ -0,0 +1,80 @@
+import logging
+import torch
+import torch.nn as nn
+
+from torch.nn import Dropout, Conv2d
+from torch.nn.modules.utils import _pair
+from fgvclib.models.backbones import backbone
+
+# official pretrained weights
+model_urls = {
+    'ViT-B_16': 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz',
+    'ViT-B_32': 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_32.npz',
+    'ViT-L_16': 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-L_16.npz',
+    'ViT-L_32': 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-L_32.npz',
+    'ViT-H_14': 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz'
+}
+
+cfgs = {
+    'ViT-B_16',
+    'ViT-B_32',
+    'ViT-L_16',
+    'ViT-L_32',
+    'ViT-H_14'
+}
+
+logger = logging.getLogger(__name__)
+
+
+class Embeddings(nn.Module):
+    """Construct the patch and position embeddings.
+    """
+
+    def __init__(self, cfg: dict, img_size, in_channels=3):
+        super(Embeddings, self).__init__()
+        self.hybrid = None  # the hybrid CNN stem is never built here; forward() skips it
+        img_size = _pair(img_size)
+
+        patch_size = cfg["patch_size"]
+        if cfg['split'] == 'non-overlap':
+            n_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
+            self.patch_embeddings = Conv2d(in_channels=in_channels,
+                                           out_channels=cfg['hidden_size'],
+                                           kernel_size=patch_size,
+                                           stride=patch_size)
+        elif cfg['split'] == 'overlap':
+            n_patches = ((img_size[0] - patch_size) // cfg['slide_step'] + 1) * (
+                    (img_size[1] - patch_size) // cfg['slide_step'] + 1)
+            self.patch_embeddings = Conv2d(in_channels=in_channels,
+                                           out_channels=cfg['hidden_size'],
+                                           kernel_size=patch_size,
+                                           stride=(cfg['slide_step'], cfg['slide_step']))
+        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches + 1, cfg['hidden_size']))
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, cfg['hidden_size']))
+
+        self.dropout = Dropout(cfg['dropout_rate'])
+
+    def forward(self, x):
+        B = x.shape[0]
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+
+        if self.hybrid:
+            x = self.hybrid_model(x)
+        x = self.patch_embeddings(x)
+        x = x.flatten(2)
+        x = x.transpose(-1, -2)
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        embeddings = x + self.position_embeddings
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+def _vit(cfg, model_name="ViT-B_16", **kwargs):
+    assert model_name in cfgs, "Warning: model name {} not in cfgs dict!".format(model_name)
+    return Embeddings(cfg, img_size=cfg['image_size'])
+
+
+@backbone("vit16")
+def vit16(cfg):
+    return _vit(cfg, "ViT-B_16")
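For intuition, a quick shape check of the Embeddings module above (a sketch; the cfg keys mirror the BACKBONE ARGS in transFG_ViT_B_16.yml). A 448x448 image split into non-overlapping 16x16 patches yields (448/16)^2 = 784 patch tokens plus one class token:

    import torch
    from fgvclib.models.backbones.vit import Embeddings

    cfg = {'patch_size': 16, 'split': 'non-overlap', 'slide_step': 12,
           'hidden_size': 768, 'dropout_rate': 0.1}
    emb = Embeddings(cfg, img_size=448)
    tokens = emb(torch.randn(2, 3, 448, 448))
    print(tokens.shape)  # torch.Size([2, 785, 768])
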
diff --git a/fgvclib/models/encoders/fpn.py b/fgvclib/models/encoders/fpn.py
new file mode 100644
index 0000000..53ea3f7
--- /dev/null
+++ b/fgvclib/models/encoders/fpn.py
@@ -0,0 +1,97 @@
+import torch
+import torch.nn as nn
+
+from fgvclib.models.encoders import encoder
+
+
+class FPN(nn.Module):
+
+    def __init__(self, inputs: dict, fpn_size: int, proj_type: str, upsample_type: str):
+        """
+        inputs : dictionary of torch.Tensor which comes from the backbone output.
+        fpn_size: integer, the common channel size every level is projected to.
+        proj_type:
+            in ["Conv", "Linear"]
+        upsample_type:
+            in ["Bilinear", "Conv"]
+            For convolutional backbones (e.g. ResNet, EfficientNet), "Bilinear" is
+            recommended; for ViT and Swin Transformer, use "Conv".
+        """
+        super(FPN, self).__init__()
+        assert proj_type in ["Conv", "Linear"], \
+            "FPN projection type {} is not supported yet, please choose 'Conv' or 'Linear'".format(proj_type)
+        assert upsample_type in ["Bilinear", "Conv"], \
+            "FPN upsample type {} is not supported yet, please choose 'Bilinear' or 'Conv'".format(upsample_type)
+
+        self.fpn_size = fpn_size
+        self.upsample_type = upsample_type
+        inp_names = [name for name in inputs]
+
+        for i, node_name in enumerate(inputs):
+            ### projection module
+            if proj_type == "Conv":
+                m = nn.Sequential(
+                    nn.Conv2d(inputs[node_name].size(1), inputs[node_name].size(1), 1),
+                    nn.ReLU(),
+                    nn.Conv2d(inputs[node_name].size(1), fpn_size, 1)
+                )
+            elif proj_type == "Linear":
+                m = nn.Sequential(
+                    nn.Linear(inputs[node_name].size(-1), inputs[node_name].size(-1)),
+                    nn.ReLU(),
+                    nn.Linear(inputs[node_name].size(-1), fpn_size),
+                )
+            self.add_module("Proj_" + node_name, m)
+
+            ### upsample module
+            if upsample_type == "Conv" and i != 0:
+                assert len(inputs[node_name].size()) == 3  # B, S, C
+                in_dim = inputs[node_name].size(1)
+                out_dim = inputs[inp_names[i - 1]].size(1)
+                if in_dim != out_dim:
+                    m = nn.Conv1d(in_dim, out_dim, 1)  # for spatial domain
+                else:
+                    m = nn.Identity()
+                self.add_module("Up_" + node_name, m)
+
+        if upsample_type == "Bilinear":
+            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
+
+    def upsample_add(self, x0: torch.Tensor, x1: torch.Tensor, x1_name: str):
+        """
+        return Upsample(x1) + x0
+        """
+        if self.upsample_type == "Bilinear":
+            if x1.size(-1) != x0.size(-1):
+                x1 = self.upsample(x1)
+        else:
+            x1 = getattr(self, "Up_" + x1_name)(x1)
+        return x1 + x0
+
+    def forward(self, x):
+        """
+        x : dictionary of named features, e.g.
+            {
+                "node_name1": feature1,
+                "node_name2": feature2, ...
+            }
+        """
+        ### project every level to the same dimension
+        hs = []
+        for name in x:
+            x[name] = getattr(self, "Proj_" + name)(x[name])
+            hs.append(name)
+
+        for i in range(len(hs) - 1, 0, -1):
+            x1_name = hs[i]
+            x0_name = hs[i - 1]
+            x[x0_name] = self.upsample_add(x[x0_name],
+                                           x[x1_name],
+                                           x1_name)
+        return x
+
+
+@encoder("fpn")
+def fpn(inputs, cfg: dict):
+    return FPN(inputs=inputs, fpn_size=cfg['fpn_size'], proj_type=cfg['proj_type'], upsample_type=cfg['upsample_type'])
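A dummy-input sketch of the FPN above (the shapes are made up; each level is a [B, S, C] token map as a ViT/Swin backbone would produce, and "Conv" upsampling matches token counts across levels with a 1x1 Conv1d):

    import torch
    from fgvclib.models.encoders.fpn import FPN

    feats = {'layer1': torch.randn(2, 64, 256),
             'layer2': torch.randn(2, 32, 512),
             'layer3': torch.randn(2, 16, 1024)}
    fpn = FPN(inputs=feats, fpn_size=512, proj_type='Linear', upsample_type='Conv')
    outs = fpn({k: v.clone() for k, v in feats.items()})  # forward mutates its input dict
    for name, t in outs.items():
        print(name, t.shape)  # every level now has channel dimension fpn_size=512
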
diff --git a/fgvclib/models/encoders/transformer_encoder.py b/fgvclib/models/encoders/transformer_encoder.py
new file mode 100644
index 0000000..07d9b9d
--- /dev/null
+++ b/fgvclib/models/encoders/transformer_encoder.py
@@ -0,0 +1,203 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import copy
+from ..backbones.vit import Embeddings
+import logging
+import math
+from os.path import join as pjoin
+import torch
+import torch.nn as nn
+from torch.nn import Dropout, Softmax, Linear, LayerNorm
+
+
+from fgvclib.models.encoders import encoder
+from fgvclib.models.heads.mlp import mlp
+
+logger = logging.getLogger(__name__)
+
+ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
+ATTENTION_K = "MultiHeadDotProductAttention_1/key"
+ATTENTION_V = "MultiHeadDotProductAttention_1/value"
+ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
+FC_0 = "MlpBlock_3/Dense_0"
+FC_1 = "MlpBlock_3/Dense_1"
+ATTENTION_NORM = "LayerNorm_0"
+MLP_NORM = "LayerNorm_2"
+
+
+def np2th(weights, conv=False):
+    """Possibly convert HWIO to OIHW."""
+    if conv:
+        weights = weights.transpose([3, 2, 0, 1])
+    return torch.from_numpy(weights)
+
+
+class Attention(nn.Module):
+    def __init__(self, cfg: dict):
+        super(Attention, self).__init__()
+        self.num_attention_heads = cfg['num_heads']
+        self.attention_head_size = int(cfg['hidden_size'] / self.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = Linear(cfg['hidden_size'], self.all_head_size)
+        self.key = Linear(cfg['hidden_size'], self.all_head_size)
+        self.value = Linear(cfg['hidden_size'], self.all_head_size)
+
+        self.out = Linear(cfg['hidden_size'], cfg['hidden_size'])
+        self.attn_dropout = Dropout(cfg['attention_dropout_rate'])
+        self.proj_dropout = Dropout(cfg['attention_dropout_rate'])
+
+        self.softmax = Softmax(dim=-1)
+
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(self, hidden_states):
+        mixed_query_layer = self.query(hidden_states)
+        mixed_key_layer = self.key(hidden_states)
+        mixed_value_layer = self.value(hidden_states)
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+        key_layer = self.transpose_for_scores(mixed_key_layer)
+        value_layer = self.transpose_for_scores(mixed_value_layer)
+
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        attention_probs = self.softmax(attention_scores)
+        weights = attention_probs  # keep the pre-dropout attention maps for part selection
+        attention_probs = self.attn_dropout(attention_probs)
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+        attention_output = self.out(context_layer)
+        attention_output = 
self.proj_dropout(attention_output) + return attention_output, weights + + +class Block(nn.Module): + def __init__(self, cfg: dict): + super(Block, self).__init__() + self.hidden_size = cfg['hidden_size'] + self.attention_norm = LayerNorm(cfg['hidden_size'], eps=1e-6) + self.ffn_norm = LayerNorm(cfg['hidden_size'], eps=1e-6) + self.ffn = mlp(cfg) + self.attn = Attention(cfg) + + def forward(self, x): + h = x + x = self.attention_norm(x) + x, weights = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x, weights + + def load_from(self, weights, n_block): + ROOT = f"Transformer/encoderblock_{n_block}" + with torch.no_grad(): + query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, + self.hidden_size).t() + key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() + value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, + self.hidden_size).t() + out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, + self.hidden_size).t() + + query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) + key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) + value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) + out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) + + self.attn.query.weight.copy_(query_weight) + self.attn.key.weight.copy_(key_weight) + self.attn.value.weight.copy_(value_weight) + self.attn.out.weight.copy_(out_weight) + self.attn.query.bias.copy_(query_bias) + self.attn.key.bias.copy_(key_bias) + self.attn.value.bias.copy_(value_bias) + self.attn.out.bias.copy_(out_bias) + + mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() + mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() + mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() + mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() + + self.ffn.fc1.weight.copy_(mlp_weight_0) + self.ffn.fc2.weight.copy_(mlp_weight_1) + self.ffn.fc1.bias.copy_(mlp_bias_0) + self.ffn.fc2.bias.copy_(mlp_bias_1) + + self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) + self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) + self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) + self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) + + +class Part_Attention(nn.Module): + def __init__(self): + super(Part_Attention, self).__init__() + + def forward(self, x): + length = len(x) + last_map = x[0] + for i in range(1, length): + last_map = torch.matmul(x[i], last_map) + last_map = last_map[:, :, 0, 1:] + + _, max_inx = last_map.max(2) + return _, max_inx + + +class Encoder(nn.Module): + def __init__(self, cfg: dict): + super(Encoder, self).__init__() + self.layer = nn.ModuleList() + for _ in range(cfg['num_layers'] - 1): + layer = Block(cfg) + self.layer.append(copy.deepcopy(layer)) + self.part_select = Part_Attention() + self.part_layer = Block(cfg) + self.part_norm = LayerNorm(cfg['hidden_size'], eps=1e-6) + + def forward(self, hidden_states): + attn_weights = [] + for layer in self.layer: + hidden_states, weights = layer(hidden_states) + attn_weights.append(weights) + part_num, part_inx = self.part_select(attn_weights) + part_inx = part_inx + 1 + parts = [] + B, num = part_inx.shape + for i in range(B): + parts.append(hidden_states[i, part_inx[i, :]]) + parts = 
torch.stack(parts).squeeze(1) + concat = torch.cat((hidden_states[:, 0].unsqueeze(1), parts), dim=1) + part_states, part_weights = self.part_layer(concat) + part_encoded = self.part_norm(part_states) + + return part_encoded + + +class Transformer(nn.Module): + def __init__(self, cfg: dict, img_size): + super(Transformer, self).__init__() + self.embeddings = Embeddings(cfg, img_size=img_size) + self.encoder = Encoder(cfg) + + def forward(self, input_ids): + embedding_output = self.embeddings(input_ids) + part_encoded = self.encoder(embedding_output) + return part_encoded + + +@encoder("transformer_encoder") +def transformer_encoder(cfg: dict): + return Transformer(cfg, img_size=cfg['img_size']) diff --git a/fgvclib/models/heads/gcn_combiner.py b/fgvclib/models/heads/gcn_combiner.py new file mode 100644 index 0000000..580f278 --- /dev/null +++ b/fgvclib/models/heads/gcn_combiner.py @@ -0,0 +1,102 @@ +import torch +import torch.nn as nn +from typing import Union +import copy + +from fgvclib.models.heads import head + + +class GCNCombiner(nn.Module): + + def __init__(self, + total_num_selects: int, + num_classes: int, + inputs: Union[dict, None] = None, + proj_size: Union[int, None] = None, + fpn_size: Union[int, None] = None): + """ + If building backbone without FPN, set fpn_size to None and MUST give + 'inputs' and 'proj_size', the reason of these setting is to constrain the + dimension of graph convolutional network input. + """ + super(GCNCombiner, self).__init__() + + assert inputs is not None or fpn_size is not None, \ + "To build GCN combiner, you must give one features dimension." + + ### auto-proj + self.fpn_size = fpn_size + if fpn_size is None: + for name in inputs: + if len(name) == 4: + in_size = inputs[name].size(1) + elif len(name) == 3: + in_size = inputs[name].size(2) + else: + raise ValueError("The size of output dimension of previous must be 3 or 4.") + m = nn.Sequential( + nn.Linear(in_size, proj_size), + nn.ReLU(), + nn.Linear(proj_size, proj_size) + ) + self.add_module("proj_" + name, m) + self.proj_size = proj_size + else: + self.proj_size = fpn_size + + ### build one layer structure (with adaptive module) + num_joints = total_num_selects // 32 + + self.param_pool0 = nn.Linear(total_num_selects, num_joints) + + A = torch.eye(num_joints) / 100 + 1 / 100 + self.adj1 = nn.Parameter(copy.deepcopy(A)) + self.conv1 = nn.Conv1d(self.proj_size, self.proj_size, 1) + self.batch_norm1 = nn.BatchNorm1d(self.proj_size) + + self.conv_q1 = nn.Conv1d(self.proj_size, self.proj_size // 4, 1) + self.conv_k1 = nn.Conv1d(self.proj_size, self.proj_size // 4, 1) + self.alpha1 = nn.Parameter(torch.zeros(1)) + + ### merge information + self.param_pool1 = nn.Linear(num_joints, 1) + + #### class predict + self.dropout = nn.Dropout(p=0.1) + self.classifier = nn.Linear(self.proj_size, num_classes) + + self.tanh = nn.Tanh() + + def forward(self, x): + """ + """ + hs = [] + for name in x: + if self.fpn_size is None: + hs.append(getattr(self, "proj_" + name)(x[name])) + else: + hs.append(x[name]) + hs = torch.cat(hs, dim=1).transpose(1, 2).contiguous() # B, S', C --> B, C, S + hs = self.param_pool0(hs) + ### adaptive adjacency + q1 = self.conv_q1(hs).mean(1) + k1 = self.conv_k1(hs).mean(1) + A1 = self.tanh(q1.unsqueeze(-1) - k1.unsqueeze(1)) + A1 = self.adj1 + A1 * self.alpha1 + ### graph convolution + hs = self.conv1(hs) + hs = torch.matmul(hs, A1) + hs = self.batch_norm1(hs) + ### predict + hs = self.param_pool1(hs) + hs = self.dropout(hs) + hs = hs.flatten(1) + hs = self.classifier(hs) + + return 
hs + + +@head("GCN_combiner") +def GCN_combiner(cfg: dict): + return GCNCombiner(total_num_selects=cfg['total_num_selects'], num_classes=cfg['num_classes'], + fpn_size=cfg['fpn_size']) diff --git a/fgvclib/models/heads/mlp.py b/fgvclib/models/heads/mlp.py new file mode 100644 index 0000000..a5c71a5 --- /dev/null +++ b/fgvclib/models/heads/mlp.py @@ -0,0 +1,49 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +from torch.nn import Dropout, Linear + +from fgvclib.models.heads import head + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class Mlp(nn.Module): + def __init__(self, hidden_size, mlp_dim, dropout_rate): + super(Mlp, self).__init__() + self.fc1 = Linear(hidden_size, mlp_dim) + self.fc2 = Linear(mlp_dim, hidden_size) + self.act_fn = ACT2FN["gelu"] + self.dropout = Dropout(dropout_rate) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +@head("mlp") +def mlp(cfg: dict) -> Mlp: + assert 'hidden_size' in cfg.keys() + assert 'mlp_dim' in cfg.keys() + assert 'dropout_rate' in cfg.keys() + return Mlp(hidden_size=cfg['hidden_size'], mlp_dim=cfg['mlp_dim'], dropout_rate=cfg['dropout_rate']) diff --git a/fgvclib/models/necks/weakly_selector.py b/fgvclib/models/necks/weakly_selector.py new file mode 100644 index 0000000..0390c4a --- /dev/null +++ b/fgvclib/models/necks/weakly_selector.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +from typing import Union + +from fgvclib.models.necks import neck + + +class WeaklySelector(nn.Module): + + def __init__(self, inputs: dict, num_classes: int, num_select: dict, fpn_size: Union[int, None] = None): + """ + inputs: dictionary contain torch.Tensors, which comes from backbone + [Tensor1(hidden feature1), Tensor2(hidden feature2)...] + Please note that if len(features.size) equal to 3, the order of dimension must be [B,S,C], + S mean the spatial domain, and if len(features.size) equal to 4, the order must be [B,C,H,W] + + """ + super(WeaklySelector, self).__init__() + + self.num_select = num_select + + self.fpn_size = fpn_size + ### build classifier + if self.fpn_size is None: + self.num_classes = num_classes + for name in inputs: + fs_size = inputs[name].size() + if len(fs_size) == 3: + in_size = fs_size[2] + elif len(fs_size) == 4: + in_size = fs_size[1] + m = nn.Linear(in_size, num_classes) + self.add_module("classifier_l_" + name, m) + + def forward(self, x, logits=None): + """ + x : + dictionary contain the features maps which + come from your choosen layers. + size must be [B, HxW, C] ([B, S, C]) or [B, C, H, W]. + [B,C,H,W] will be transpose to [B, HxW, C] automatically. 
+ """ + i = 0 + if self.fpn_size is None: + logits = {} + selections = {} + for name in x: + if len(x[name].size()) == 4: + B, C, H, W = x[name].size() + x[name] = x[name].view(B, C, H * W).permute(0, 2, 1).contiguous() + C = x[name].size(-1) + if self.fpn_size is None: + logits[name] = getattr(self, "classifier_l_" + name)(x[name]) + + probs = torch.softmax(logits[name], dim=-1) + selections[name] = [] + preds_1 = [] + preds_0 = [] + num_select = self.num_select[i][name] + i = i + 1 + for bi in range(logits[name].size(0)): + max_ids, _ = torch.max(probs[bi], dim=-1) + confs, ranks = torch.sort(max_ids, descending=True) + sf = x[name][bi][ranks[:num_select]] + nf = x[name][bi][ranks[num_select:]] # calculate + selections[name].append(sf) # [num_selected, C] + preds_1.append(logits[name][bi][ranks[:num_select]]) + preds_0.append(logits[name][bi][ranks[num_select:]]) + + selections[name] = torch.stack(selections[name]) + preds_1 = torch.stack(preds_1) + preds_0 = torch.stack(preds_0) + + logits["select_" + name] = preds_1 + logits["drop_" + name] = preds_0 + + return selections + + +@neck("weakly_selector") +def weakly_selector(inputs, cfg: dict): + return WeaklySelector(inputs=inputs, num_classes=cfg['num_classes'], num_select=cfg['num_selects'], + fpn_size=cfg['fpn_size']) diff --git a/fgvclib/models/sotas/api_net.py b/fgvclib/models/sotas/api_net.py index badf664..5e50580 100644 --- a/fgvclib/models/sotas/api_net.py +++ b/fgvclib/models/sotas/api_net.py @@ -14,8 +14,8 @@ class APINet(FGVCSOTA): Link: https://github.com/PeiqinZhuang/API-Net """ - def __init__(self, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: dict): - super().__init__(backbone, encoder, necks, heads, criterions) + def __init__(self, cfg: dict, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): + super().__init__(cfg, backbone, encoder, necks, heads, criterions) def forward(self, images, targets=None): diff --git a/fgvclib/models/sotas/cal.py b/fgvclib/models/sotas/cal.py index 466ca9d..cf00c63 100644 --- a/fgvclib/models/sotas/cal.py +++ b/fgvclib/models/sotas/cal.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from yacs.config import CfgNode from .sota import FGVCSOTA from fgvclib.criterions.utils import LossItem @@ -67,13 +68,14 @@ class WSDAN_CAL(FGVCSOTA): Link: https://github.com/raoyongming/CAL """ - def __init__(self, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: dict): - super().__init__(backbone, encoder, necks, heads, criterions) - self.num_classes = 200 + def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): + super().__init__(cfg, backbone, encoder, necks, heads, criterions) + + self.out_channels = 2048 + self.num_classes = cfg.NUM_CLASS self.M = 32 self.net = 'resnet101' - - self.register_buffer('feature_center', torch.zeros(self.num_classes, self.M * 2048)) # 32 * 2048 + self.register_buffer('feature_center', torch.zeros(self.num_classes, self.M * self.out_channels)) # 32 * 2048 def infer(self, x): diff --git a/fgvclib/models/sotas/mcl.py b/fgvclib/models/sotas/mcl.py index ffde2a3..e181fc4 100644 --- a/fgvclib/models/sotas/mcl.py +++ b/fgvclib/models/sotas/mcl.py @@ -1,5 +1,6 @@ from torch import nn +from yacs.config import CfgNode from fgvclib.models.sotas.sota import FGVCSOTA from fgvclib.models.sotas import fgvcmodel @@ -12,14 +13,16 @@ class 
MCL(FGVCSOTA): Link: https://github.com/PRIS-CV/Mutual-Channel-Loss """ - def __init__(self, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): - super().__init__(backbone, encoder, necks, heads, criterions) + def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): + super().__init__(cfg, backbone, encoder, necks, heads, criterions) + self.num_classes = cfg.CLASS_NUM + def forward(self, x, targets=None): x = self.backbone(x) if self.training: losses = list() - losses.extend(self.criterions['mutual_channel_loss']['fn'](x, targets, self.heads.get_class_num())) + losses.extend(self.criterions['mutual_channel_loss']['fn'](x, targets, self.num_classes)) x = self.encoder(x) x = x.view(x.size(0), -1) diff --git a/fgvclib/models/sotas/pmg.py b/fgvclib/models/sotas/pmg.py index 5aa394f..e8f1f7c 100644 --- a/fgvclib/models/sotas/pmg.py +++ b/fgvclib/models/sotas/pmg.py @@ -1,5 +1,6 @@ import torch.nn as nn import random +from yacs.config import CfgNode from fgvclib.models.sotas.sota import FGVCSOTA from fgvclib.criterions.utils import LossItem @@ -14,8 +15,8 @@ class PMG(FGVCSOTA): """ - def __init__(self, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): - super().__init__(backbone, encoder, necks, heads, criterions) + def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): + super().__init__(cfg, backbone, encoder, necks, heads, criterions) self.outputs_num = 4 diff --git a/fgvclib/models/sotas/resnet50.py b/fgvclib/models/sotas/resnet50.py index 0769cc7..e9f8e62 100644 --- a/fgvclib/models/sotas/resnet50.py +++ b/fgvclib/models/sotas/resnet50.py @@ -1,4 +1,5 @@ import torch.nn as nn +from yacs.config import CfgNode from fgvclib.models.sotas.sota import FGVCSOTA from fgvclib.models.sotas import fgvcmodel @@ -9,8 +10,8 @@ @fgvcmodel("ResNet50") class ResNet50(FGVCSOTA): - def __init__(self, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: dict): - super().__init__(backbone, encoder, necks, heads, criterions) + def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): + super().__init__(cfg, backbone, encoder, necks, heads, criterions) def forward(self, x, targets=None): x = self.infer(x) diff --git a/fgvclib/models/sotas/sota.py b/fgvclib/models/sotas/sota.py index 5e85845..2145056 100644 --- a/fgvclib/models/sotas/sota.py +++ b/fgvclib/models/sotas/sota.py @@ -1,17 +1,22 @@ -import torch +from yacs.config import CfgNode import torch.nn as nn from thop import profile +from fgvclib.configs.utils import turn_list_to_dict as tltd + + class FGVCSOTA(nn.Module): - def __init__(self, backbone:nn.Module, encoder:nn.Module, necks:nn.Module, heads:nn.Module, criterions:nn.Module): + def __init__(self, cfg:CfgNode, backbone:nn.Module, encoder:nn.Module, necks:nn.Module, heads:nn.Module, criterions:nn.Module): super(FGVCSOTA, self).__init__() - + + self.cfg = cfg self.backbone = backbone self.necks = necks self.encoder = encoder self.heads = heads self.criterions = criterions + self.args = tltd(cfg.ARGS) def get_structure(self): ss = f"\n{'=' * 30}\n" + f"\nThe Structure of {self.__class__.__name__}:\n" + f"\n{'=' * 30}\n\n" diff --git a/fgvclib/models/sotas/swin_T.py b/fgvclib/models/sotas/swin_T.py new file mode 100644 index 0000000..22f49e9 
--- /dev/null
+++ b/fgvclib/models/sotas/swin_T.py
@@ -0,0 +1,104 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from yacs.config import CfgNode
+
+from .sota import FGVCSOTA
+from fgvclib.models.sotas import fgvcmodel
+
+
+@fgvcmodel("Swin_T")
+class Swin_T(FGVCSOTA):
+    def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module):
+        super().__init__(cfg, backbone, encoder, necks, heads, criterions)
+
+        ### get hidden feature size
+        img_size = self.args["img_size"]
+        num_classes = cfg.CLASS_NUM
+        rand_in = torch.randn(1, 3, img_size, img_size)
+        outs = self.backbone(rand_in)
+
+        ### = = = = = FPN = = = = =
+        self.fpn = self.encoder
+        fpn_size = 1536  # must match fpn_size in the encoder config
+        self.build_fpn_classifier(outs, fpn_size, num_classes)
+
+        ### = = = = = Selector = = = = =
+        self.selector = self.necks
+
+        ### = = = = = Combiner = = = = =
+        self.combiner = self.heads
+
+        ### just original backbone
+        if not self.fpn and (not self.combiner):
+            for name in outs:
+                fs_size = outs[name].size()
+                if len(fs_size) == 3:
+                    out_size = fs_size[-1]
+                elif len(fs_size) == 4:
+                    out_size = fs_size[1]
+                else:
+                    raise ValueError("The size of output dimension of previous must be 3 or 4.")
+            self.classifier = nn.Linear(out_size, num_classes)
+
+    def build_fpn_classifier(self, inputs: dict, fpn_size: int, num_classes: int):
+        for name in inputs:
+            m = nn.Sequential(
+                nn.Conv1d(fpn_size, fpn_size, 1),
+                nn.BatchNorm1d(fpn_size),
+                nn.ReLU(),
+                nn.Conv1d(fpn_size, num_classes, 1)
+            )
+            self.add_module("fpn_classifier_" + name, m)
+
+    def forward_backbone(self, x):
+        return self.backbone(x)
+
+    def fpn_predict(self, x: dict, logits: dict):
+        """
+        x: [B, C, H, W] or [B, S, C]
+        [B, C, H, W] --> [B, C, H*W] for the 1-D conv classifier
+        """
+        for name in x:
+            ### predict on each features point
+            if len(x[name].size()) == 4:
+                B, C, H, W = x[name].size()
+                logit = x[name].view(B, C, H * W)
+            elif len(x[name].size()) == 3:
+                logit = x[name].transpose(1, 2).contiguous()
+            logits[name] = getattr(self, "fpn_classifier_" + name)(logit)
+            logits[name] = logits[name].transpose(1, 2).contiguous()  # back to [B, S, num_classes]
+
+    def forward(self, x: torch.Tensor):
+
+        logits = {}
+        x = self.forward_backbone(x)
+
+        if self.fpn:
+            x = self.fpn(x)
+            self.fpn_predict(x, logits)
+
+        if self.selector:
+            selects = self.selector(x, logits)
+
+        if self.combiner:
+            comb_outs = self.combiner(selects)
+            logits['comb_outs'] = comb_outs
+            return logits
+
+        if self.selector or self.fpn:
+            return logits
+
+        ### original backbone (only predict final selected layer)
+        for name in x:
+            hs = x[name]
+
+            if len(hs.size()) == 4:
+                hs = F.adaptive_avg_pool2d(hs, (1, 1))
+                hs = hs.flatten(1)
+            else:
+                hs = hs.mean(1)
+            out = self.classifier(hs)
+            logits['ori_out'] = out
+
+        return logits
diff --git a/fgvclib/models/sotas/transFG.py b/fgvclib/models/sotas/transFG.py
new file mode 100644
index 0000000..7e284c7
--- /dev/null
+++ b/fgvclib/models/sotas/transFG.py
@@ -0,0 +1,130 @@
+# coding=utf-8
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import torch.nn.functional as F
+import numpy as np
+import torch
+import torch.nn as nn
+import os.path as op
+from scipy import ndimage
+from torch.nn import Linear
+from yacs.config import CfgNode
+
+from fgvclib.criterions import LossItem
+from fgvclib.models.sotas.sota import FGVCSOTA
+from fgvclib.models.sotas import fgvcmodel
+
+
+
+def np2th(weights, conv=False):
+    """Possibly convert HWIO to OIHW."""
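+    # TensorFlow stores conv kernels as [H, W, I, O]; PyTorch expects [O, I, H, W].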
+ if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +@fgvcmodel("TransFG") +class TransFG(FGVCSOTA): + def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): + super().__init__(cfg, backbone, encoder, necks, heads, criterions) + + self.num_classes = cfg.CLASS_NUM + self.smoothing_value = self.args["smoothing_value"] + self.zero_head = self.args["zero_head"] + self.classifier = self.args["classifier"] + self.part_head = Linear(self.args["part_head_in_channels"], self.num_classes) + self.pretrained_weight = self.args["pretrained_weight"] + + if op.exists(self.pretrained_weight): + print(f"Loading pretraining weight in {self.pretrained_weight}.") + if self.pretrained_weight.endswith('.npz'): + self.load_from(np.load(self.pretrained_weight)) + else: + self.load_state_dict(torch.load(self.pretrained_weight)) + + + def forward(self, x, labels=None): + part_tokens = self.encoder(x) + part_logits = self.part_head(part_tokens[:, 0]) + if labels is not None: + losses = list() + if self.smoothing_value == 0: + part_loss = self.criterions['cross_entropy_loss']['fn'](part_logits.view(-1, self.num_classes), + labels.view(-1)) + else: + part_loss = self.criterions['nll_loss_labelsmoothing']['fn'](part_logits.view(-1, self.num_classes), + labels.view(-1)) + contrast_loss = con_loss(part_tokens[:, 0], labels.view(-1)) + losses.append(LossItem(name='contrast_loss', value=contrast_loss, weight=1.0)) + losses.append(LossItem(name='part_loss', value=part_loss, weight=1.0)) + return part_logits, losses + else: + return part_logits + + def load_from(self, weights): + with torch.no_grad(): + self.encoder.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) + self.encoder.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) + self.encoder.embeddings.cls_token.copy_(np2th(weights["cls"])) + self.encoder.encoder.part_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) + self.encoder.encoder.part_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) + + posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) + posemb_new = self.encoder.embeddings.position_embeddings + if posemb.size() == posemb_new.size(): + self.encoder.embeddings.position_embeddings.copy_(posemb) + else: + + ntok_new = posemb_new.size(1) + + if self.classifier == "token": + posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] + ntok_new -= 1 + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + + gs_old = int(np.sqrt(len(posemb_grid))) + gs_new = int(np.sqrt(ntok_new)) + print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) + posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) + + zoom = (gs_new / gs_old, gs_new / gs_old, 1) + posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) + posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) + posemb = np.concatenate([posemb_tok, posemb_grid], axis=1) + self.encoder.embeddings.position_embeddings.copy_(np2th(posemb)) + + for bname, block in self.encoder.encoder.named_children(): + if bname.startswith('part') == False: + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=uname) + + if self.encoder.embeddings.hybrid: + self.encoder.embeddings.hybrid_model.root.conv.weight.copy_( + np2th(weights["conv_root/kernel"], conv=True)) + gn_weight = np2th(weights["gn_root/scale"]).view(-1) + gn_bias = np2th(weights["gn_root/bias"]).view(-1) + 
self.encoder.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
+                self.encoder.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)
+
+                for bname, block in self.encoder.embeddings.hybrid_model.body.named_children():
+                    for uname, unit in block.named_children():
+                        unit.load_from(weights, n_block=bname, n_unit=uname)
+
+
+
+
+# contrastive loss on normalized cls features: pull same-class pairs together,
+# push different-class pairs beyond a 0.4 cosine margin
+def con_loss(features, labels):
+    B, _ = features.shape
+    features = F.normalize(features)
+    cos_matrix = features.mm(features.t())
+    pos_label_matrix = torch.stack([labels == labels[i] for i in range(B)]).float()
+    neg_label_matrix = 1 - pos_label_matrix
+    pos_cos_matrix = 1 - cos_matrix
+    neg_cos_matrix = cos_matrix - 0.4
+    neg_cos_matrix[neg_cos_matrix < 0] = 0
+    loss = (pos_cos_matrix * pos_label_matrix).sum() + (neg_cos_matrix * neg_label_matrix).sum()
+    loss /= (B * B)
+    return loss
diff --git a/fgvclib/utils/lr_schedules/constant_schedule.py b/fgvclib/utils/lr_schedules/constant_schedule.py
new file mode 100644
index 0000000..3ae9c14
--- /dev/null
+++ b/fgvclib/utils/lr_schedules/constant_schedule.py
@@ -0,0 +1,10 @@
+from torch.optim.lr_scheduler import LambdaLR
+
+from .lr_schedule import LRSchedule
+
+
+class ConstantSchedule(LambdaLR, LRSchedule):
+    """ Constant learning rate schedule.
+    """
+    def __init__(self, optimizer, total_epoch=-1):
+        super(ConstantSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=total_epoch)
\ No newline at end of file
diff --git a/fgvclib/utils/lr_schedules/cosine_decay_schedule.py b/fgvclib/utils/lr_schedules/cosine_decay_schedule.py
new file mode 100644
index 0000000..d2157ae
--- /dev/null
+++ b/fgvclib/utils/lr_schedules/cosine_decay_schedule.py
@@ -0,0 +1,33 @@
+import math
+import numpy as np
+
+
+def cosine_decay(max_epochs, warmup_batchs, max_lr, batchs: int, decay_type: int = 1):
+    total_batchs = max_epochs * batchs
+    iters = np.arange(total_batchs - warmup_batchs)
+
+    if decay_type == 1:
+        schedule = np.array([1e-12 + 0.5 * (max_lr - 1e-12) * (1 + math.cos(math.pi * t / total_batchs))
+                             for t in iters])
+    elif decay_type == 2:
+        schedule = max_lr * np.array([math.cos(7 * math.pi * t / (16 * total_batchs)) for t in iters])
+    else:
+        raise ValueError("Unsupported decay type: {}".format(decay_type))
+
+    if warmup_batchs > 0:
+        warmup_lr_schedule = np.linspace(1e-9, max_lr, warmup_batchs)
+        schedule = np.concatenate((warmup_lr_schedule, schedule))
+
+    return schedule
+
+
+def get_lr(optimizer):
+    for param_group in optimizer.param_groups:
+        if param_group["lr"] is not None:
+            return param_group["lr"]
+
+
+def adjust_lr(iteration, optimizer, schedule):
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = schedule[iteration]
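A sketch of how cosine_decay, adjust_lr and get_lr fit together in a batch-level loop (the numbers are illustrative, echoing WARMUP_BATCHS and MAX_LR in the swin_transformer config):

    import torch
    from fgvclib.utils.lr_schedules.cosine_decay_schedule import cosine_decay, adjust_lr, get_lr

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=0.005)
    batchs_per_epoch = 100
    schedule = cosine_decay(max_epochs=50, warmup_batchs=800, max_lr=0.005, batchs=batchs_per_epoch)
    for epoch in range(50):
        for batch_id in range(batchs_per_epoch):
            iteration = epoch * batchs_per_epoch + batch_id
            adjust_lr(iteration, optimizer, schedule)  # one lr value per training batch
            # ... forward / backward / optimizer.step() ...
    print(get_lr(optimizer))  # final lr after the cosine decay
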
diff --git a/fgvclib/utils/update_function/__init__.py b/fgvclib/utils/update_function/__init__.py
new file mode 100644
index 0000000..374fe27
--- /dev/null
+++ b/fgvclib/utils/update_function/__init__.py
@@ -0,0 +1,34 @@
+import os
+import importlib
+
+
+__UPDATE_FN_DICT__ = {}
+
+
+def get_update_function(name):
+    r"""Return the update function with the given name.
+        Args: 
+            name (str): 
+                The name of the update function.
+        Return: 
+            The registered update function.
+    """
+
+    return __UPDATE_FN_DICT__[name]
+
+def update_function(name):
+    
+    def register_function_fn(cls):
+        if name in __UPDATE_FN_DICT__:
+            raise ValueError("Name %s already registered!" % name)
+        __UPDATE_FN_DICT__[name] = cls
+        return cls
+
+    return register_function_fn
+
+for file in os.listdir(os.path.dirname(__file__)):
+    if file.endswith('.py') and not file.startswith('_'):
+        module_name = file[:file.find('.py')]
+        module = importlib.import_module('fgvclib.utils.update_function.' + module_name)
\ No newline at end of file
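The amp flag threaded through general_update() below ultimately drives the standard torch.cuda.amp recipe. As a reference, a minimal sketch with a persistent scaler (note that general_updating in this patch recreates the GradScaler on every call, so the dynamic loss scale cannot adapt across batches; model, optimizer and train_loader are assumed to be in scope):

    import torch
    from torch.cuda.amp import autocast, GradScaler

    from fgvclib.criterions import compute_loss_value

    scaler = GradScaler()  # persists across batches so the loss scale can adapt
    for inputs, targets in train_loader:  # assumed training loop
        optimizer.zero_grad()
        with autocast():
            out, losses = model(inputs.cuda(), targets.cuda())
            total_loss = compute_loss_value(losses)
        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()
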
diff --git a/fgvclib/utils/update_function/general_update.py b/fgvclib/utils/update_function/general_update.py
new file mode 100644
index 0000000..764f0b6
--- /dev/null
+++ b/fgvclib/utils/update_function/general_update.py
@@ -0,0 +1,189 @@
+from typing import Iterable
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from torch.optim import Optimizer
+
+
+from fgvclib.utils.update_strategy import get_update_strategy
+from fgvclib.utils.logger import Logger
+from fgvclib.utils.lr_schedules import LRSchedule
+from fgvclib.utils.lr_schedules.cosine_decay_schedule import adjust_lr, get_lr
+# NOTE: update_swintModel also calls cal_train_metrics, which is ported from the
+# reference Swin-T training script and is not defined in fgvclib yet.
+
+
+def general_update(
+        model: nn.Module, optimizer: Optimizer, pbar:Iterable, lr_schedule:LRSchedule=None, 
+        strategy:str="general_updating", use_cuda:bool=True, logger:Logger=None, 
+        epoch:int=None, total_epoch:int=None, amp:bool=False, **kwargs
+):
+    r"""Update the FGVC model and record losses.
+
+    Args:
+        model (nn.Module): The FGVC model.
+        optimizer (Optimizer): The optimizer used for training.
+        pbar (Iterable): An iterable object providing the training data.
+        lr_schedule (LRSchedule): The lr schedule updating class.
+        strategy (string): The update strategy.
+        use_cuda (boolean): Whether to use GPU to train the model.
+        logger (Logger): The Logger object.
+        epoch (int): The current epoch number.
+        total_epoch (int): The total epoch number.
+        amp (boolean): Whether to train with automatic mixed precision.
+    """
+
+    model.train()
+    mean_loss = 0.
+    for batch_idx, train_data in enumerate(pbar):
+        losses_info = get_update_strategy(strategy)(model, train_data, optimizer, use_cuda, amp)
+        mean_loss = (mean_loss * batch_idx + losses_info['iter_loss']) / (batch_idx + 1)
+        losses_info.update({"mean_loss": mean_loss})
+        logger(losses_info, step=batch_idx)
+        pbar.set_postfix(losses_info)
+        if lr_schedule.update_level == 'batch_update':
+            lr_schedule.step(optimizer=optimizer, batch_idx=batch_idx, batch_size=len(train_data), current_epoch=epoch, total_epoch=total_epoch)
+
+    if lr_schedule.update_level == 'epoch_update':
+        lr_schedule.step(optimizer=optimizer, current_epoch=epoch, total_epoch=total_epoch)
+
+
+def update_vitmodel(model: nn.Module, optimizer: Optimizer, scheduler, pbar: Iterable,
+                    strategy: str = "vit_updating",
+                    use_cuda: bool = True, logger: Logger = None):
+    r"""Update the FGVC model and record losses.
+
+    Args:
+        model (nn.Module): The FGVC model.
+        optimizer (Optimizer): The optimizer used for training.
+        scheduler : The learning rate scheduler.
+        pbar (Iterable): An iterable object providing the training data.
+        strategy (string): The update strategy.
+        use_cuda (boolean): Whether to use GPU to train the model.
+        logger (Logger): The Logger object.
+    """
+    model.train()
+    mean_loss = 0.
+    for batch_idx, train_data in enumerate(pbar):
+        losses_info = get_update_strategy(strategy)(model, train_data, optimizer, scheduler, use_cuda)
+        mean_loss = (mean_loss * batch_idx + losses_info['iter_loss']) / (batch_idx + 1)
+        losses_info.update({"mean_loss": mean_loss})
+        logger(losses_info, step=batch_idx)
+        pbar.set_postfix(losses_info)
+
+
+def update_swintModel(args, epoch, model, scaler, amp_context, optimizer, schedule, train_bar, logger: Logger = None):
+    optimizer.zero_grad()
+    total_batchs = len(train_bar)  # just for log
+    show_progress = [x / 10 for x in range(11)]  # just for log
+    progress_i = 0
+    for batch_id, (datas, labels) in enumerate(train_bar):
+        model.train()
+        """ = = = = adjust learning rate = = = = """
+        iterations = epoch * len(train_bar) + batch_id
+        adjust_lr(iterations, optimizer, schedule)
+
+        batch_size = labels.size(0)
+
+        """ = = = = forward and calculate loss = = = =  """
+        datas, labels = datas.cuda(), labels.cuda()
+        datas, labels = Variable(datas), Variable(labels)
+        with amp_context():
+            """
+            [Model Return]
+            FPN + Selector + Combiner --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depends on your setting)
+                'preds_0', 'preds_1', 'comb_outs'
+            FPN + Selector --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depends on your setting)
+                'preds_0', 'preds_1'
+            FPN --> return 'layer1', 'layer2', 'layer3', 'layer4' (depends on your setting)
+            ~ --> return 'ori_out'
+
+            [Return Tensor]
+            'preds_0': logits that were not selected by the Selector.
+            'preds_1': logits that were selected by the Selector.
+            'comb_outs': the prediction of the Combiner.
+            """
+            outs = model(datas)
+
+            loss = 0.
+            for name in outs:
+
+                if "select_" in name:
+                    if not args.use_selection:
+                        raise ValueError("Selector is not used here.")
+                    if args.lambda_s != 0:
+                        S = outs[name].size(1)
+                        logit = outs[name].view(-1, args.num_classes).contiguous()
+                        loss_s = nn.CrossEntropyLoss()(logit,
+                                                       labels.unsqueeze(1).repeat(1, S).flatten(0))
+                        loss += args.lambda_s * loss_s
+                    else:
+                        loss_s = 0.0
+
+                elif "drop_" in name:
+                    if not args.use_selection:
+                        raise ValueError("Selector is not used here.")
+
+                    if args.lambda_n != 0:
+                        S = outs[name].size(1)
+                        logit = outs[name].view(-1, args.num_classes).contiguous()
+                        n_preds = nn.Tanh()(logit)
+                        labels_0 = torch.zeros([batch_size * S, args.num_classes]) - 1
+                        labels_0 = labels_0.cuda()
+                        loss_n = nn.MSELoss()(n_preds, labels_0)
+                        loss += args.lambda_n * loss_n
+                    else:
+                        loss_n = 0.0
+
+                elif "layer" in name:
+                    if not args.use_fpn:
+                        raise ValueError("FPN is not used here.")
+                    if args.lambda_b != 0:
+                        ### using 'layer1'~'layer4' is the default setting; change it to fit your own layers
+                        loss_b = nn.CrossEntropyLoss()(outs[name].mean(1), labels)
+                        loss += args.lambda_b * loss_b
+                    else:
+                        loss_b = 0.0
+
+                elif "comb_outs" in name:
+                    if not args.use_combiner:
+                        raise ValueError("Combiner is not used here.")
+
+                    if args.lambda_c != 0:
+                        loss_c = nn.CrossEntropyLoss()(outs[name], labels)
+                        loss += args.lambda_c * loss_c
+
+                elif "ori_out" in name:
+                    loss_ori = F.cross_entropy(outs[name], labels)
+                    loss += loss_ori
+
+            loss /= args.update_freq
+
+        """ = = = = calculate gradient = = = = """
+        if args.use_amp:
+            scaler.scale(loss).backward()
+        else:
+            loss.backward()
+
+        """ = = = = update model = = = = """
+        if (batch_id + 1) % args.update_freq == 0:
+            if args.use_amp:
+                scaler.step(optimizer)
+                scaler.update()  # next batch
+            else:
+                optimizer.step()
+            optimizer.zero_grad()
+
+        """ log """
+        if (batch_id + 1) % args.log_freq == 0:
+            model.eval()
+            msg = {}
+            msg['info/epoch'] = epoch + 1
+ msg['info/lr'] = get_lr(optimizer) + msg['loss'] = loss.item() + cal_train_metrics(args, msg, outs, labels, batch_size) + logger(msg) + + train_progress = (batch_id + 1) / total_batchs + # print(train_progress, show_progress[progress_i]) + if train_progress > show_progress[progress_i]: + print(".." + str(int(show_progress[progress_i] * 100)) + "%", end='', flush=True) + progress_i += 1 + + + \ No newline at end of file diff --git a/fgvclib/utils/update_strategy/general_updating.py b/fgvclib/utils/update_strategy/general_updating.py index be49923..17d5694 100644 --- a/fgvclib/utils/update_strategy/general_updating.py +++ b/fgvclib/utils/update_strategy/general_updating.py @@ -1,18 +1,30 @@ from torch.autograd import Variable import typing as t +from torch.cuda.amp import autocast, GradScaler from fgvclib.criterions import compute_loss_value, detach_loss_value -def general_updating(model, train_data, optimizer, use_cuda=True, **kwargs) -> t.Dict: +def general_updating(model, train_data, optimizer, use_cuda=True, amp=False, **kwargs) -> t.Dict: + if amp: + scaler = GradScaler() inputs, targets = train_data if use_cuda: inputs, targets = inputs.cuda(), targets.cuda() inputs, targets = Variable(inputs), Variable(targets) - out, losses = model(inputs, targets) - total_loss = compute_loss_value(losses) - total_loss.backward() - optimizer.step() + if amp: + with autocast(): + out, losses = model(inputs, targets) + total_loss = compute_loss_value(losses) + scaler.scale(total_loss).backward() + scaler.step(optimizer) + scaler.update() + else: + out, losses = model(inputs, targets) + total_loss = compute_loss_value(losses) + total_loss.backward() + optimizer.step() + optimizer.zero_grad() losses_info = detach_loss_value(losses) losses_info.update({"iter_loss": total_loss.item()}) diff --git a/fgvclib/utils/update_strategy/vit_updating.py b/fgvclib/utils/update_strategy/vit_updating.py new file mode 100644 index 0000000..b393e69 --- /dev/null +++ b/fgvclib/utils/update_strategy/vit_updating.py @@ -0,0 +1,24 @@ +import numpy as np +import torch +from torch.autograd import Variable +import typing as t +from fgvclib.criterions import compute_loss_value, detach_loss_value + + +def vit_updating(model, train_data, optimizer, scheduler, use_cuda=True) -> t.Dict: + inputs, targets = train_data + if use_cuda: + inputs, targets = inputs.cuda(), targets.cuda() + inputs, targets = Variable(inputs), Variable(targets) + out, losses = model(inputs, targets) + total_loss = compute_loss_value(losses) + total_loss = total_loss.mean() + total_loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + scheduler.step() + optimizer.zero_grad() + losses_info = detach_loss_value(losses) + losses_info.update({"iter_loss": total_loss.item()}) + + return losses_info diff --git a/main.py b/main.py index 765794c..52eaa7b 100644 --- a/main.py +++ b/main.py @@ -7,7 +7,8 @@ from fgvclib.apis import * from fgvclib.configs import FGVCConfig -from fgvclib.utils import cosine_anneal_schedule, init_distributed_mode +from fgvclib.utils import init_distributed_mode +from fgvclib.utils.update_function.general_update import general_update def train(cfg: CfgNode): @@ -90,12 +91,10 @@ def train(cfg: CfgNode): logger(f'Epoch: {epoch + 1} / {cfg.EPOCH_NUM} Training') - - - update_model( + general_update( model, optimizer, train_bar, strategy=cfg.UPDATE_STRATEGY, use_cuda=cfg.USE_CUDA, lr_schedule=lr_schedule, - logger=logger, epoch=epoch, total_epoch=cfg.EPOCH_NUM + logger=logger, epoch=epoch, 
total_epoch=cfg.EPOCH_NUM, amp=cfg.AMP ) test_bar = tqdm(test_loader) @@ -112,7 +111,7 @@ def train(cfg: CfgNode): model_with_ddp = model.module else: model_with_ddp = model - save_model(cfg=cfg, model=model, logger=logger) + save_model(cfg=cfg, model=model_with_ddp, logger=logger) logger.finish() def predict(cfg: CfgNode): @@ -157,6 +156,7 @@ def predict(cfg: CfgNode): config = FGVCConfig() config.load(args.config) cfg = config.cfg + set_seed(cfg.SEED) if args.distributed: cfg.DISTRIBUTED = args.distributed cfg.GPU = args.gpu From bc5c62c4f310969040a9ab8c1c5f5647b360af7c Mon Sep 17 00:00:00 2001 From: xin-ran-w Date: Tue, 17 Jan 2023 20:42:16 +0800 Subject: [PATCH 02/10] update transFG --- .gitignore | 1 + configs/resnet/resnet50.yml | 2 +- configs/transfg/transFG_ViT_B_16.yml | 11 +- fgvclib/apis/build.py | 19 +- fgvclib/configs/config.py | 3 +- fgvclib/criterions/mutual_channel_loss.py | 3 +- fgvclib/criterions/nll_loss_labelsmoothing.py | 32 +++ fgvclib/models/heads/mlp.py | 2 +- fgvclib/utils/lr_schedules/__init__.py | 5 +- .../utils/lr_schedules/adjusting_schedule.py | 24 +- .../utils/lr_schedules/constant_schedule.py | 4 +- .../lr_schedules/cosine_anneal_schedule.py | 16 +- .../lr_schedules/cosine_decay_schedule.py | 18 +- fgvclib/utils/lr_schedules/lr_schedule.py | 6 +- .../lr_schedules/warmup_cosine_schedule.py | 28 +++ .../utils/update_function/general_update.py | 153 +----------- fgvclib/utils/update_function/update_swint.py | 233 ++++++++++++++++++ fgvclib/utils/update_strategy/__init__.py | 42 ++-- .../{general_updating.py => general.py} | 5 +- ... => progressive_consistency_constraint.py} | 4 +- .../progressive_updating_with_jigsaw.py | 3 + fgvclib/utils/update_strategy/vit_updating.py | 7 +- main.py | 6 +- 23 files changed, 411 insertions(+), 216 deletions(-) create mode 100644 fgvclib/criterions/nll_loss_labelsmoothing.py create mode 100644 fgvclib/utils/lr_schedules/warmup_cosine_schedule.py create mode 100644 fgvclib/utils/update_function/update_swint.py rename fgvclib/utils/update_strategy/{general_updating.py => general.py} (88%) rename fgvclib/utils/update_strategy/{progressive_updating_consistency_constraint.py => progressive_consistency_constraint.py} (93%) diff --git a/.gitignore b/.gitignore index 1d0337c..d54c776 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ __pycache__ .ipynb_checkpoints cam_images/ .vscode +ut.py diff --git a/configs/resnet/resnet50.yml b/configs/resnet/resnet50.yml index 7c52d2a..485d607 100644 --- a/configs/resnet/resnet50.yml +++ b/configs/resnet/resnet50.yml @@ -98,7 +98,7 @@ OPTIMIZER: ITERATION_NUM: ~ EPOCH_NUM: 1 START_EPOCH: 0 -UPDATE_STRATEGY: "general_updating" +UPDATE_STRATEGY: "general_strategy" # Validation details PER_ITERATION: ~ diff --git a/configs/transfg/transFG_ViT_B_16.yml b/configs/transfg/transFG_ViT_B_16.yml index bf277a5..fd5c668 100644 --- a/configs/transfg/transFG_ViT_B_16.yml +++ b/configs/transfg/transFG_ViT_B_16.yml @@ -125,12 +125,15 @@ OPTIMIZER: necks: 0.08 heads: 0.08 -WARMUP_STEPS: 500 -NUM_STEPS: 10000 -ITERATION_NUM: ~ EPOCH_NUM: 100 START_EPOCH: 0 -UPDATE_STRATEGY: "vit_updating" +UPDATE_STRATEGY: "vit_update_strategy" +LR_SCHEDULE: + NAME: "warmup_linear_schedule" + ARGS: + - warmup_steps: 500 + - total_steps: 10000 +AMP: True # Validation details diff --git a/fgvclib/apis/build.py b/fgvclib/apis/build.py index 2c4febf..b0ffb6c 100644 --- a/fgvclib/apis/build.py +++ b/fgvclib/apis/build.py @@ -30,6 +30,7 @@ from fgvclib.utils.logger import get_logger, Logger from fgvclib.utils.interpreter import 
get_interpreter, Interpreter
 from fgvclib.utils.lr_schedules import get_lr_schedule, LRSchedule
+from fgvclib.utils.update_function import get_update_function
 
 
 def build_model(model_cfg: CfgNode) -> FGVCSOTA:
@@ -220,8 +221,8 @@ def build_sampler(sampler_cfg: CfgNode) -> Sampler:
     return get_sampler(sampler_cfg.NAME)
 
 
-def build_lr_schedule(schedule_cfg: CfgNode) -> LRSchedule:
-    r"""Build metrics for evaluation.
+def build_lr_schedule(optimizer, schedule_cfg: CfgNode) -> LRSchedule:
+    r"""Build lr_schedule for training.
 
     Args:
         schedule_cfg (CfgNode): The schedule config node of root config node.
@@ -231,4 +232,16 @@
 
     """
 
-    return get_lr_schedule(schedule_cfg.NAME)(tltd(schedule_cfg.ARGS))
+    return get_lr_schedule(schedule_cfg.NAME)(optimizer, tltd(schedule_cfg.ARGS))
+
+def build_update_function(cfg) -> t.Callable:
+    r"""Build the update function for training.
+
+    Args:
+        cfg (CfgNode): The root config node.
+    Returns:
+        Callable: An update function.
+
+    """
+
+    return get_update_function(cfg.UPDATE_FUNCTION)
diff --git a/fgvclib/configs/config.py b/fgvclib/configs/config.py
index e17fd23..47d8824 100644
--- a/fgvclib/configs/config.py
+++ b/fgvclib/configs/config.py
@@ -124,7 +124,8 @@ def __init__(self):
         self.cfg.ITERATION_NUM = None
         self.cfg.EPOCH_NUM = None
         self.cfg.START_EPOCH = None
-        self.cfg.UPDATE_STRATEGY = None
+        self.cfg.UPDATE_FUNCTION = "general_update"
+        self.cfg.UPDATE_STRATEGY = "general_strategy"
         self.cfg.LR_SCHEDULE = CN()
         self.cfg.LR_SCHEDULE.NAME = "cosine_anneal_schedule"
         self.cfg.LR_SCHEDULE.ARGS = None
diff --git a/fgvclib/criterions/mutual_channel_loss.py b/fgvclib/criterions/mutual_channel_loss.py
index b2f5410..4dfe631 100644
--- a/fgvclib/criterions/mutual_channel_loss.py
+++ b/fgvclib/criterions/mutual_channel_loss.py
@@ -7,6 +7,7 @@
 from torch.nn.modules.utils import _pair
 
 from .utils import LossItem
+from . import criterion
 
 
 class MutualChannelLoss(nn.Module):
@@ -147,7 +148,7 @@ def __repr__(self) -> str:
             + ', count_include_pad=' + str(self.count_include_pad) + ')'
 
-
+@criterion("mutual_channel_loss")
 def mutual_channel_loss(cfg=None):
     assert 'height' in cfg.keys(), 'height must exist in parameters'
     assert 'cnum' in cfg.keys(), 'cnum must exist in parameters'
diff --git a/fgvclib/criterions/nll_loss_labelsmoothing.py b/fgvclib/criterions/nll_loss_labelsmoothing.py
new file mode 100644
index 0000000..c8a8690
--- /dev/null
+++ b/fgvclib/criterions/nll_loss_labelsmoothing.py
@@ -0,0 +1,32 @@
+import torch
+import torch.nn as nn
+
+from . import criterion
+
+
+class LabelSmoothing(nn.Module):
+    """
+    NLL loss with label smoothing.
+    """
+
+    def __init__(self, smoothing):
+        """
+        Constructor for the LabelSmoothing module.
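+        The resulting loss is (1 - smoothing) * NLL(target)
+        + smoothing * mean(-log p), i.e. the target NLL mixed
+        with a uniform distribution over all classes.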
+ :param smoothing: label smoothing factor + """ + super(LabelSmoothing, self).__init__() + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + + def forward(self, x, target): + logprobs = torch.nn.functional.log_softmax(x, dim=-1) + + nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) + nll_loss = nll_loss.squeeze(1) + smooth_loss = -logprobs.mean(dim=-1) + loss = self.confidence * nll_loss + self.smoothing * smooth_loss + return loss.mean() + +@criterion("nll_loss_labelsmoothing") +def nll_loss_labelsmoothing(cfg=None): + return LabelSmoothing(smoothing=cfg['smoothing_value']) diff --git a/fgvclib/models/heads/mlp.py b/fgvclib/models/heads/mlp.py index a5c71a5..1dbc12c 100644 --- a/fgvclib/models/heads/mlp.py +++ b/fgvclib/models/heads/mlp.py @@ -42,7 +42,7 @@ def forward(self, x): @head("mlp") -def mlp(cfg: dict) -> Mlp: +def mlp(cfg: dict, **kwargs) -> Mlp: assert 'hidden_size' in cfg.keys() assert 'mlp_dim' in cfg.keys() assert 'dropout_rate' in cfg.keys() diff --git a/fgvclib/utils/lr_schedules/__init__.py b/fgvclib/utils/lr_schedules/__init__.py index 6ed8cd4..9720003 100644 --- a/fgvclib/utils/lr_schedules/__init__.py +++ b/fgvclib/utils/lr_schedules/__init__.py @@ -1,5 +1,6 @@ import os import importlib +from torch.optim.lr_scheduler import _LRScheduler from fgvclib.utils.lr_schedules.lr_schedule import LRSchedule @@ -23,8 +24,8 @@ def lr_schedule(name): def register_function_fn(cls): if name in __LR_SCHEDULE_DICT__: raise ValueError("Name %s already registered!" % name) - if not issubclass(cls, LRSchedule): - raise ValueError("Class %s is not a subclass of %s" % (cls, LRSchedule)) + # if not issubclass(cls, LRSchedule) and not issubclass(cls, _LRScheduler): + # raise ValueError("Class %s is not a subclass of %s or %s" % (cls, LRSchedule, _LRScheduler)) __LR_SCHEDULE_DICT__[name] = cls return cls diff --git a/fgvclib/utils/lr_schedules/adjusting_schedule.py b/fgvclib/utils/lr_schedules/adjusting_schedule.py index a624574..f471c11 100644 --- a/fgvclib/utils/lr_schedules/adjusting_schedule.py +++ b/fgvclib/utils/lr_schedules/adjusting_schedule.py @@ -1,18 +1,24 @@ from .lr_schedule import LRSchedule from . 
import lr_schedule -@lr_schedule("adjusting_schedule") + class Adjusting_Schedule(LRSchedule): - def __init__(self, cfg) -> None: - super().__init__(cfg) - self.base_rate = cfg["base_rate"] - self.base_duration = cfg["base_duration"] - self.base_lr = cfg["base_lr"] - self.update_level = 'batch_update' + def __init__(self, optimizer, base_rate, base_duration, base_lr) -> None: + super().__init__(optimizer) + + self.base_rate = base_rate + self.base_duration = base_duration + self.base_lr = base_lr + self.update_level = "batch" - def step(self, optimizer, batch_idx, current_epoch, batch_size, **kwargs): - iter = float(batch_idx) / batch_size + def step(self, optimizer, batch_idx, current_epoch, total_batch, **kwargs): + iter = float(batch_idx) / total_batch lr = self.base_lr * pow(self.base_rate, (current_epoch + iter) / self.base_duration) for pg in optimizer.param_groups: pg['lr'] = lr + + +@lr_schedule("adjusting_schedule") +def adjusting_schedule(optimizer, cfg:dict): + return Adjusting_Schedule(optimizer, **cfg) diff --git a/fgvclib/utils/lr_schedules/constant_schedule.py b/fgvclib/utils/lr_schedules/constant_schedule.py index 3ae9c14..6c2fdea 100644 --- a/fgvclib/utils/lr_schedules/constant_schedule.py +++ b/fgvclib/utils/lr_schedules/constant_schedule.py @@ -1,9 +1,7 @@ from torch.optim.lr_scheduler import LambdaLR -from .lr_schedule import LRSchedule - -class ConstantSchedule(LambdaLR, LRSchedule): +class ConstantSchedule(LambdaLR): """ Constant learning rate schedule. """ def __init__(self, optimizer, total_epoch=-1): diff --git a/fgvclib/utils/lr_schedules/cosine_anneal_schedule.py b/fgvclib/utils/lr_schedules/cosine_anneal_schedule.py index da135ee..6ca37cd 100644 --- a/fgvclib/utils/lr_schedules/cosine_anneal_schedule.py +++ b/fgvclib/utils/lr_schedules/cosine_anneal_schedule.py @@ -3,18 +3,22 @@ from .lr_schedule import LRSchedule from . 
import lr_schedule
 
-@lr_schedule("cosine_anneal_schedule")
+
 class CosineAnnealSchedule(LRSchedule):
 
-    def __init__(self, cfg) -> None:
-        super().__init__(cfg)
-        self.update_level = 'epoch_update'
+    def __init__(self, optimizer) -> None:
+        super().__init__(optimizer)
+        self.update_level = 'epoch'
 
-    def step(self, optimizer, current_epoch, total_epoch, **kwargs):
+    def step(self, current_epoch, total_epoch, **kwargs):
         cos_inner = np.pi * (current_epoch % (total_epoch))
         cos_inner /= (total_epoch)
         cos_out = np.cos(cos_inner) + 1
-        for pg in optimizer.param_groups:
+        for pg in self.optimizer.param_groups:
             current_lr = pg['lr']
             pg['lr'] = float(current_lr / 2 * cos_out)
+
+
+@lr_schedule("cosine_anneal_schedule")
+def cosine_anneal_schedule(optimizer, cfg:dict):
+    return CosineAnnealSchedule(optimizer, **cfg)
diff --git a/fgvclib/utils/lr_schedules/cosine_decay_schedule.py b/fgvclib/utils/lr_schedules/cosine_decay_schedule.py
index d2157ae..ef77d78 100644
--- a/fgvclib/utils/lr_schedules/cosine_decay_schedule.py
+++ b/fgvclib/utils/lr_schedules/cosine_decay_schedule.py
@@ -7,27 +7,15 @@ def cosine_decay(max_epochs, warmup_batchs, max_lr, batchs: int, decay_type: int
     iters = np.arange(total_batchs - warmup_batchs)
 
     if decay_type == 1:
-        schedule = np.array([1e-12 + 0.5 * (max_lr - 1e-12) * (1 + \
-                             math.cos(math.pi * t / total_batchs)) for t in
-                             iters])
+        schedule = np.array([1e-12 + 0.5 * (max_lr - 1e-12) * (1 + math.cos(math.pi * t / total_batchs)) for t in iters])
     elif decay_type == 2:
         schedule = max_lr * np.array([math.cos(7 * math.pi * t / (16 * total_batchs)) for t in iters])
     else:
-        raise ValueError("Not support this deccay type")
+        raise ValueError("Unsupported decay type")
 
     if warmup_batchs > 0:
         warmup_lr_schedule = np.linspace(1e-9, max_lr, warmup_batchs)
         schedule = np.concatenate((warmup_lr_schedule, schedule))
 
     return schedule
-
-
-def get_lr(optimizer):
-    for param_group in optimizer.param_groups:
-        if param_group["lr"] is not None:
-            return param_group["lr"]
-
-
-def adjust_lr(iteration, optimizer, schedule):
-    for param_group in optimizer.param_groups:
-        param_group["lr"] = schedule[iteration]
+    
\ No newline at end of file
diff --git a/fgvclib/utils/lr_schedules/lr_schedule.py b/fgvclib/utils/lr_schedules/lr_schedule.py
index d995b8e..f31557b 100644
--- a/fgvclib/utils/lr_schedules/lr_schedule.py
+++ b/fgvclib/utils/lr_schedules/lr_schedule.py
@@ -1,7 +1,7 @@
 class LRSchedule:
 
-    def __init__(self, cfg) -> None:
-        self.cfg = cfg
+    def __init__(self, optimizer) -> None:
+        self.optimizer = optimizer
 
-    def step(**kwargs):
+    def step(self):
         raise NotImplementedError("Each subclass of LRSchedule should implement the step method.")
diff --git a/fgvclib/utils/lr_schedules/warmup_cosine_schedule.py b/fgvclib/utils/lr_schedules/warmup_cosine_schedule.py
new file mode 100644
index 0000000..ddb8854
--- /dev/null
+++ b/fgvclib/utils/lr_schedules/warmup_cosine_schedule.py
@@ -0,0 +1,28 @@
+from torch.optim.lr_scheduler import LambdaLR
+
+from . 
import lr_schedule + + + +class WarmupLinearSchedule(LambdaLR): + + + def __init__(self, optimizer, warmup_steps, total_steps, last_epoch=-1): + self.warmup_steps = warmup_steps + self.total_steps = total_steps + super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) + self.update_level = "batch" + + def lr_lambda(self, step): + if step < self.warmup_steps: + return float(step) / float(max(1, self.warmup_steps)) + return max(0.0, float(self.total_steps - step) / float(max(1.0, self.total_steps - self.warmup_steps))) + + def step(self, **kwargs): + return super().step() + + +@lr_schedule("warmup_linear_schedule") +def warmup_linear_schedule(optimizer, cfg:dict): + return WarmupLinearSchedule(optimizer=optimizer, **cfg) + diff --git a/fgvclib/utils/update_function/general_update.py b/fgvclib/utils/update_function/general_update.py index 764f0b6..4b863d0 100644 --- a/fgvclib/utils/update_function/general_update.py +++ b/fgvclib/utils/update_function/general_update.py @@ -1,14 +1,16 @@ from typing import Iterable -import torch -import torch.nn as nn from torch.optim import Optimizer +import torch.nn as nn +from . import update_function from fgvclib.utils.update_strategy import get_update_strategy from fgvclib.utils.logger import Logger from fgvclib.utils.lr_schedules import LRSchedule + +@update_function("general_update") def general_update( model: nn.Module, optimizer: Optimizer, pbar:Iterable, lr_schedule:LRSchedule=None, strategy:str="general_updating", use_cuda:bool=True, logger:Logger=None, @@ -36,154 +38,15 @@ def general_update( losses_info.update({"mean_loss": mean_loss}) logger(losses_info, step=batch_idx) pbar.set_postfix(losses_info) - if lr_schedule.update_level == 'batch_update': - lr_schedule.step(optimizer=optimizer, batch_idx=batch_idx, batch_size=len(train_data), current_epoch=epoch, total_epoch=total_epoch) + if lr_schedule.update_level == 'batch': + lr_schedule.step(batch_idx=batch_idx, total_batch=len(pbar), current_epoch=epoch, total_epoch=total_epoch) - if lr_schedule.update_level == 'epoch_update': - lr_schedule.step(optimizer=optimizer, current_epoch=epoch, total_epoch=total_epoch) - - -def update_vitmodel(model: nn.Module, optimizer: Optimizer, scheduler, pbar: Iterable, - strategy: str = "vit_updating", - use_cuda: bool = True, logger: Logger = None): - r"""Update the FGVC model and record losses. - - Args: - model (nn.Module): The FGVC model. - optimizer (Optimizer): The Logger object. - scheduler : The scheduler strategy - pbar (Iterable): A iterable object provide training data. - strategy (string): The update strategy. - use_cuda (boolean): Whether to use GPU to train the model. - logger (Logger): The Logger object. - """ - model.train() - mean_loss = 0. 
- for batch_idx, train_data in enumerate(pbar): - losses_info = get_update_strategy(strategy)(model, train_data, optimizer, scheduler, use_cuda) - mean_loss = (mean_loss * batch_idx + losses_info['iter_loss']) / (batch_idx + 1) - losses_info.update({"mean_loss": mean_loss}) - logger(losses_info, step=batch_idx) - pbar.set_postfix(losses_info) - - -def update_swintModel(args, epoch, model, scaler, amp_context, optimizer, schedule, train_bar, logger: Logger = None): - optimizer.zero_grad() - total_batchs = len(train_bar) # just for log - show_progress = [x / 10 for x in range(11)] # just for log - progress_i = 0 - for batch_id, (datas, labels) in enumerate(train_bar): - model.train() - """ = = = = adjust learning rate = = = = """ - iterations = epoch * len(train_bar) + batch_id - adjust_lr(iterations, optimizer, schedule) - - batch_size = labels.size(0) - - """ = = = = forward and calculate loss = = = = """ - datas, labels = datas.cuda(), labels.cuda() - datas, labels = Variable(datas), Variable(labels) - with amp_context(): - """ - [Model Return] - FPN + Selector + Combiner --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depend on your setting) - 'preds_0', 'preds_1', 'comb_outs' - FPN + Selector --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depend on your setting) - 'preds_0', 'preds_1' - FPN --> return 'layer1', 'layer2', 'layer3', 'layer4' (depend on your setting) - ~ --> return 'ori_out' - - [Retuen Tensor] - 'preds_0': logit has not been selected by Selector. - 'preds_1': logit has been selected by Selector. - 'comb_outs': The prediction of combiner. - """ - outs = model(datas) - - loss = 0. - for name in outs: - - if "select_" in name: - if not args.use_selection: - raise ValueError("Selector not use here.") - if args.lambda_s != 0: - S = outs[name].size(1) - logit = outs[name].view(-1, args.num_classes).contiguous() - loss_s = nn.CrossEntropyLoss()(logit, - labels.unsqueeze(1).repeat(1, S).flatten(0)) - loss += args.lambda_s * loss_s - else: - loss_s = 0.0 - - elif "drop_" in name: - if not args.use_selection: - raise ValueError("Selector not use here.") - - if args.lambda_n != 0: - S = outs[name].size(1) - logit = outs[name].view(-1, args.num_classes).contiguous() - n_preds = nn.Tanh()(logit) - labels_0 = torch.zeros([batch_size * S, args.num_classes]) - 1 - labels_0 = labels_0.cuda() - loss_n = nn.MSELoss()(n_preds, labels_0) - loss += args.lambda_n * loss_n - else: - loss_n = 0.0 - - elif "layer" in name: - if not args.use_fpn: - raise ValueError("FPN not use here.") - if args.lambda_b != 0: - ### here using 'layer1'~'layer4' is default setting, you can change to your own - loss_b = nn.CrossEntropyLoss()(outs[name].mean(1), labels) - loss += args.lambda_b * loss_b - else: - loss_b = 0.0 - - elif "comb_outs" in name: - if not args.use_combiner: - raise ValueError("Combiner not use here.") - - if args.lambda_c != 0: - loss_c = nn.CrossEntropyLoss()(outs[name], labels) - loss += args.lambda_c * loss_c - - elif "ori_out" in name: - loss_ori = F.cross_entropy(outs[name], labels) - loss += loss_ori - - loss /= args.update_freq + if lr_schedule.update_level == 'epoch': + lr_schedule.step(current_epoch=epoch, total_epoch=total_epoch) - """ = = = = calculate gradient = = = = """ - if args.use_amp: - scaler.scale(loss).backward() - else: - loss.backward() - """ = = = = update model = = = = """ - if (batch_id + 1) % args.update_freq == 0: - if args.use_amp: - scaler.step(optimizer) - scaler.update() # next batch - else: - optimizer.step() - optimizer.zero_grad() - """ log """ 
-            if (batch_id + 1) % args.log_freq == 0:
-                model.eval()
-                msg = {}
-                msg['info/epoch'] = epoch + 1
-                msg['info/lr'] = get_lr(optimizer)
-                msg['loss'] = loss.item()
-                cal_train_metrics(args, msg, outs, labels, batch_size)
-                logger(msg)
-
-            train_progress = (batch_id + 1) / total_batchs
-            # print(train_progress, show_progress[progress_i])
-            if train_progress > show_progress[progress_i]:
-                print(".." + str(int(show_progress[progress_i] * 100)) + "%", end='', flush=True)
-                progress_i += 1
\ No newline at end of file
diff --git a/fgvclib/utils/update_function/update_swint.py b/fgvclib/utils/update_function/update_swint.py
new file mode 100644
index 0000000..5386b0a
--- /dev/null
+++ b/fgvclib/utils/update_function/update_swint.py
@@ -0,0 +1,233 @@
+from typing import Iterable
+from torch.autograd import Variable
+from torch.optim import Optimizer
+from torch.cuda.amp import autocast, GradScaler
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+
+from . import update_function
+from fgvclib.utils.logger import Logger
+
+
+
+@update_function("update_swint")
+def update_swint(model: nn.Module, optimizer: Optimizer, pbar:Iterable, lr_schedule=None,
+    strategy:str="update_swint", use_cuda:bool=True, logger:Logger=None,
+    epoch:int=None, total_epoch:int=None, amp:bool=False, use_selection=False, **kwargs,
+):
+    optimizer.zero_grad()
+    total_batchs = len(pbar)  # just for log
+    show_progress = [x / 10 for x in range(11)]  # just for log
+    progress_i = 0
+    for batch_id, (datas, labels) in enumerate(pbar):
+        model.train()
+        """ = = = = adjust learning rate = = = = """
+        iterations = epoch * len(pbar) + batch_id
+
+        for param_group in optimizer.param_groups:
+            param_group["lr"] = lr_schedule[iterations]
+
+        batch_size = labels.size(0)
+
+        """ = = = = forward and calculate loss = = = = """
+        datas, labels = datas.cuda(), labels.cuda()
+        datas, labels = Variable(datas), Variable(labels)
+        with autocast():
+            """
+            [Model Return]
+            FPN + Selector + Combiner --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depend on your setting)
+                                          'preds_0', 'preds_1', 'comb_outs'
+            FPN + Selector --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depend on your setting)
+                               'preds_0', 'preds_1'
+            FPN --> return 'layer1', 'layer2', 'layer3', 'layer4' (depend on your setting)
+            ~ --> return 'ori_out'
+
+            [Return Tensor]
+            'preds_0': logits that have not been selected by the Selector.
+            'preds_1': logits that have been selected by the Selector.
+            'comb_outs': the prediction of the Combiner.
+            """
+            outs = model(datas)
+
+            loss = 0.
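+            # every recognized output head adds a weighted term: lambda_s scales
+            # CE on selected logits, lambda_n pulls dropped logits towards -1
+            # (tanh + MSE), lambda_b scales CE on the averaged FPN layer logits,
+            # lambda_c scales CE on the combiner output; 'ori_out' uses plain CE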
+ for name in outs: + + if "select_" in name: + if not use_selection: + raise ValueError("Selector not use here.") + if args.lambda_s != 0: + S = outs[name].size(1) + logit = outs[name].view(-1, args.num_classes).contiguous() + loss_s = nn.CrossEntropyLoss()(logit, + labels.unsqueeze(1).repeat(1, S).flatten(0)) + loss += args.lambda_s * loss_s + else: + loss_s = 0.0 + + elif "drop_" in name: + if not use_selection: + raise ValueError("Selector not use here.") + + if args.lambda_n != 0: + S = outs[name].size(1) + logit = outs[name].view(-1, args.num_classes).contiguous() + n_preds = nn.Tanh()(logit) + labels_0 = torch.zeros([batch_size * S, args.num_classes]) - 1 + labels_0 = labels_0.cuda() + loss_n = nn.MSELoss()(n_preds, labels_0) + loss += args.lambda_n * loss_n + else: + loss_n = 0.0 + + elif "layer" in name: + if not args.use_fpn: + raise ValueError("FPN not use here.") + if args.lambda_b != 0: + ### here using 'layer1'~'layer4' is default setting, you can change to your own + loss_b = nn.CrossEntropyLoss()(outs[name].mean(1), labels) + loss += args.lambda_b * loss_b + else: + loss_b = 0.0 + + elif "comb_outs" in name: + if not args.use_combiner: + raise ValueError("Combiner not use here.") + + if args.lambda_c != 0: + loss_c = nn.CrossEntropyLoss()(outs[name], labels) + loss += args.lambda_c * loss_c + + elif "ori_out" in name: + loss_ori = F.cross_entropy(outs[name], labels) + loss += loss_ori + + loss /= args.update_freq + + """ = = = = calculate gradient = = = = """ + if amp: + scaler.scale(loss).backward() + else: + loss.backward() + + """ = = = = update model = = = = """ + if (batch_id + 1) % args.update_freq == 0: + if args.use_amp: + scaler.step(optimizer) + scaler.update() # next batch + else: + optimizer.step() + optimizer.zero_grad() + + """ log """ + if (batch_id + 1) % args.log_freq == 0: + model.eval() + msg = {} + msg['info/epoch'] = epoch + 1 + msg['info/lr'] = get_lr(optimizer) + msg['loss'] = loss.item() + cal_train_metrics(args, msg, outs, labels, batch_size) + logger(msg) + + train_progress = (batch_id + 1) / total_batchs + # print(train_progress, show_progress[progress_i]) + if train_progress > show_progress[progress_i]: + print(".." 
+ str(int(show_progress[progress_i] * 100)) + "%", end='', flush=True) + progress_i += 1 + + + + + + +@torch.no_grad() +def cal_train_metrics(msg: dict, outs: dict, labels: torch.Tensor, batch_size: int, use_fpn, use_selection, use_combiner): + """ + only present top-1 training accuracy + """ + + total_loss = 0.0 + + if use_fpn: + for i in range(1, 5): + acc = top_k_corrects(outs["layer"+str(i)].mean(1), labels, tops=[1])["top-1"] / batch_size + acc = round(acc * 100, 2) + msg["train_acc/layer{}_acc".format(i)] = acc + loss = F.cross_entropy(outs["layer"+str(i)].mean(1), labels) + msg["train_loss/layer{}_loss".format(i)] = loss.item() + total_loss += loss.item() + + if use_selection: + for name in outs: + if "select_" not in name: + continue + B, S, _ = outs[name].size() + logit = outs[name].view(-1, args.num_classes) + labels_0 = labels.unsqueeze(1).repeat(1, S).flatten(0) + acc = top_k_corrects(logit, labels_0, tops=[1])["top-1"] / (B*S) + acc = round(acc * 100, 2) + msg["train_acc/{}_acc".format(name)] = acc + labels_0 = torch.zeros([B * S, num_classes]) - 1 + labels_0 = labels_0.cuda() + loss = F.mse_loss(F.tanh(logit), labels_0) + msg["train_loss/{}_loss".format(name)] = loss.item() + total_loss += loss.item() + + for name in outs: + if "drop_" not in name: + continue + B, S, _ = outs[name].size() + logit = outs[name].view(-1, num_classes) + labels_1 = labels.unsqueeze(1).repeat(1, S).flatten(0) + acc = top_k_corrects(logit, labels_1, tops=[1])["top-1"] / (B*S) + acc = round(acc * 100, 2) + msg["train_acc/{}_acc".format(name)] = acc + loss = F.cross_entropy(logit, labels_1) + msg["train_loss/{}_loss".format(name)] = loss.item() + total_loss += loss.item() + + if use_combiner: + acc = top_k_corrects(outs['comb_outs'], labels, tops=[1])["top-1"] / batch_size + acc = round(acc * 100, 2) + msg["train_acc/combiner_acc"] = acc + loss = F.cross_entropy(outs['comb_outs'], labels) + msg["train_loss/combiner_loss"] = loss.item() + total_loss += loss.item() + + if "ori_out" in outs: + acc = top_k_corrects(outs["ori_out"], labels, tops=[1])["top-1"] / batch_size + acc = round(acc * 100, 2) + msg["train_acc/ori_acc"] = acc + loss = F.cross_entropy(outs["ori_out"], labels) + msg["train_loss/ori_loss"] = loss.item() + total_loss += loss.item() + + msg["train_loss/total_loss"] = total_loss + + +@torch.no_grad() +def top_k_corrects(preds: torch.Tensor, labels: torch.Tensor, tops: list = [1, 3, 5]): + """ + preds: [B, C] (C is num_classes) + labels: [B, ] + """ + if preds.device != torch.device('cpu'): + preds = preds.cpu() + if labels.device != torch.device('cpu'): + labels = labels.cpu() + tmp_cor = 0 + corrects = {"top-"+str(x):0 for x in tops} + sorted_preds = torch.sort(preds, dim=-1, descending=True)[1] + for i in range(tops[-1]): + tmp_cor += sorted_preds[:, i].eq(labels).sum().item() + # records + if "top-"+str(i+1) in corrects: + corrects["top-"+str(i+1)] = tmp_cor + return corrects + +def get_lr(optimizer): + for param_group in optimizer.param_groups: + if param_group["lr"] is not None: + return param_group["lr"] + + diff --git a/fgvclib/utils/update_strategy/__init__.py b/fgvclib/utils/update_strategy/__init__.py index 66fc984..02f2c26 100644 --- a/fgvclib/utils/update_strategy/__init__.py +++ b/fgvclib/utils/update_strategy/__init__.py @@ -1,22 +1,34 @@ -from .progressive_updating_with_jigsaw import progressive_updating_with_jigsaw -from .progressive_updating_consistency_constraint import progressive_updating_consistency_constraint -from .general_updating import general_updating +import os 
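+# Registry pattern: each strategy module registers itself under a string name
+# via the @update_strategy decorator defined below, e.g. (hypothetical):
+#
+#   @update_strategy("my_strategy")
+#   def my_strategy(model, train_data, optimizer, **kwargs): ...
+#
+# and is later retrieved with get_update_strategy("my_strategy").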
+import importlib
 
-__all__ = [
-    'progressive_updating_with_jigsaw', 'progressive_updating_consistency_constraint', 'general_updating'
-]
+from fgvclib.utils.lr_schedules.lr_schedule import LRSchedule
+from torch.optim import SGD
 
-def get_update_strategy(strategy_name):
-    r"""Return the learning rate schedule with the given name.
+__UPDATE_ST_DICT__ = {}
+
+def get_update_strategy(name):
+    r"""Return the update strategy with the given name.
 
     Args: 
-        strategy_name (str): 
-            The name of the update strategy.
-        
+        name (str): 
+            The name of the update strategy.
     Return: 
-        The update strategy contructor method.
+        (Callable): The registered update strategy.
     """
+    
+    return __UPDATE_ST_DICT__[name]
+
+def update_strategy(name):
+    
+    def register_function_fn(cls):
+        if name in __UPDATE_ST_DICT__:
+            raise ValueError("Name %s already registered!" % name)
+        __UPDATE_ST_DICT__[name] = cls
+        return cls
+    
+    return register_function_fn
 
-    if strategy_name not in globals():
-        raise NotImplementedError(f"Strategy not found: {strategy_name}\nAvailable strategy: {__all__}")
-    return globals()[strategy_name]
+for file in os.listdir(os.path.dirname(__file__)):
+    if file.endswith('.py') and not file.startswith('_'):
+        module_name = file[:file.find('.py')]
+        module = importlib.import_module('fgvclib.utils.update_strategy.' + module_name)
\ No newline at end of file
diff --git a/fgvclib/utils/update_strategy/general_updating.py b/fgvclib/utils/update_strategy/general.py
similarity index 88%
rename from fgvclib/utils/update_strategy/general_updating.py
rename to fgvclib/utils/update_strategy/general.py
index 17d5694..30c54ef 100644
--- a/fgvclib/utils/update_strategy/general_updating.py
+++ b/fgvclib/utils/update_strategy/general.py
@@ -2,10 +2,13 @@
 import typing as t
 from torch.cuda.amp import autocast, GradScaler
 
+from . import update_strategy
 from fgvclib.criterions import compute_loss_value, detach_loss_value
 
 
-def general_updating(model, train_data, optimizer, use_cuda=True, amp=False, **kwargs) -> t.Dict:
+
+@update_strategy("general_strategy")
+def general_strategy(model, train_data, optimizer, use_cuda=True, amp=False, **kwargs) -> t.Dict:
     if amp:
         scaler = GradScaler()
     inputs, targets = train_data
diff --git a/fgvclib/utils/update_strategy/progressive_updating_consistency_constraint.py b/fgvclib/utils/update_strategy/progressive_consistency_constraint.py
similarity index 93%
rename from fgvclib/utils/update_strategy/progressive_updating_consistency_constraint.py
rename to fgvclib/utils/update_strategy/progressive_consistency_constraint.py
index b06f151..3ff53f4 100644
--- a/fgvclib/utils/update_strategy/progressive_updating_consistency_constraint.py
+++ b/fgvclib/utils/update_strategy/progressive_consistency_constraint.py
@@ -3,11 +3,11 @@
 from torch import nn, Tensor
 from torch.autograd import Variable
 
+from . 
import update_strategy from fgvclib.criterions import compute_loss_value, detach_loss_value -BLOCKS = [[8, 8, 0, 0], [4, 4, 4, 0], [2, 2, 2, 2]] -alpha = [0.01, 0.05, 0.1] +@update_strategy("progressive_updating_consistency_constraint") def progressive_updating_consistency_constraint(model:nn.Module, train_data:t.Tuple[Tensor, Tensor, Tensor], optimizer, use_cuda=True, **kwargs) -> t.Dict: inputs, positive_inputs, targets = train_data batch_size = inputs.size(0) diff --git a/fgvclib/utils/update_strategy/progressive_updating_with_jigsaw.py b/fgvclib/utils/update_strategy/progressive_updating_with_jigsaw.py index c5e13b6..d7d4f42 100644 --- a/fgvclib/utils/update_strategy/progressive_updating_with_jigsaw.py +++ b/fgvclib/utils/update_strategy/progressive_updating_with_jigsaw.py @@ -1,8 +1,11 @@ import random from torch.autograd import Variable +from . import update_strategy from fgvclib.criterions import compute_loss_value, detach_loss_value + +@update_strategy("update_strategy") def progressive_updating_with_jigsaw(model, train_data, optimizer, use_cuda=True, **kwargs): inputs, targets = train_data if use_cuda: diff --git a/fgvclib/utils/update_strategy/vit_updating.py b/fgvclib/utils/update_strategy/vit_updating.py index b393e69..4004c87 100644 --- a/fgvclib/utils/update_strategy/vit_updating.py +++ b/fgvclib/utils/update_strategy/vit_updating.py @@ -2,10 +2,14 @@ import torch from torch.autograd import Variable import typing as t + +from ..update_strategy import update_strategy from fgvclib.criterions import compute_loss_value, detach_loss_value -def vit_updating(model, train_data, optimizer, scheduler, use_cuda=True) -> t.Dict: +@update_strategy("vit_update_strategy") +def vit_update_strategy(model, train_data, optimizer, use_cuda=True, amp=False, **kwargs) -> t.Dict: + inputs, targets = train_data if use_cuda: inputs, targets = inputs.cuda(), targets.cuda() @@ -16,7 +20,6 @@ def vit_updating(model, train_data, optimizer, scheduler, use_cuda=True) -> t.Di total_loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() - scheduler.step() optimizer.zero_grad() losses_info = detach_loss_value(losses) losses_info.update({"iter_loss": total_loss.item()}) diff --git a/main.py b/main.py index 52eaa7b..7a03241 100644 --- a/main.py +++ b/main.py @@ -81,7 +81,9 @@ def train(cfg: CfgNode): metrics = build_metrics(cfg.METRICS) - lr_schedule = build_lr_schedule(cfg.LR_SCHEDULE) + lr_schedule = build_lr_schedule(optimizer, cfg.LR_SCHEDULE) + + update_fn = build_update_function(cfg) for epoch in range(cfg.START_EPOCH, cfg.EPOCH_NUM): if args.distributed: @@ -91,7 +93,7 @@ def train(cfg: CfgNode): logger(f'Epoch: {epoch + 1} / {cfg.EPOCH_NUM} Training') - general_update( + update_fn( model, optimizer, train_bar, strategy=cfg.UPDATE_STRATEGY, use_cuda=cfg.USE_CUDA, lr_schedule=lr_schedule, logger=logger, epoch=epoch, total_epoch=cfg.EPOCH_NUM, amp=cfg.AMP From 1c29de2d191de055e84bd042fca9ae7b97d4eb99 Mon Sep 17 00:00:00 2001 From: xin-ran-w Date: Tue, 17 Jan 2023 21:23:18 +0800 Subject: [PATCH 03/10] update config --- configs/mutual_channel_loss/mcl_vgg16.yml | 2 +- .../pmg_v2_resnet50.yml | 10 ++ configs/swin_transformer/swin_transformer.yml | 1 + fgvclib/models/sotas/pmg_v2.py | 12 +- fgvclib/models/sotas/swin_T.py | 4 +- .../utils/update_function/general_update.py | 2 +- fgvclib/utils/update_function/update_swint.py | 160 ++++-------------- .../{general.py => general_strategy.py} | 0 .../progressive_updating_with_jigsaw.py | 2 +- 9 files changed, 52 insertions(+), 141 
deletions(-) rename fgvclib/utils/update_strategy/{general.py => general_strategy.py} (100%) diff --git a/configs/mutual_channel_loss/mcl_vgg16.yml b/configs/mutual_channel_loss/mcl_vgg16.yml index 675c961..6d776e4 100644 --- a/configs/mutual_channel_loss/mcl_vgg16.yml +++ b/configs/mutual_channel_loss/mcl_vgg16.yml @@ -112,7 +112,7 @@ OPTIMIZER: ITERATION_NUM: ~ EPOCH_NUM: 200 START_EPOCH: 0 -UPDATE_STRATEGY: "general_updating" +UPDATE_STRATEGY: "general_strategy" # Validation details PER_ITERATION: ~ diff --git a/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml b/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml index 1b76992..7bf908f 100644 --- a/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml +++ b/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml @@ -28,6 +28,16 @@ DATASET: MODEL: NAME: "PMG_V2" CLASS_NUM: 200 + ARGS: + - outputs_num: 3 + - BLOCKS: + - [8, 8, 0, 0] + - [4, 4, 4, 0] + - [2, 2, 2, 2] + - alpha: + - 0.01 + - 0.05 + - 0.1 CRITERIONS: - name: "cross_entropy_loss" args: [] diff --git a/configs/swin_transformer/swin_transformer.yml b/configs/swin_transformer/swin_transformer.yml index ac252c3..0b3c4b1 100644 --- a/configs/swin_transformer/swin_transformer.yml +++ b/configs/swin_transformer/swin_transformer.yml @@ -30,6 +30,7 @@ MODEL: CLASS_NUM: 200 ARGS: - img_size: 384 + - fpn_size: 1536 CRITERIONS: - name: "cross_entropy_loss" diff --git a/fgvclib/models/sotas/pmg_v2.py b/fgvclib/models/sotas/pmg_v2.py index b124c7b..10d58b5 100644 --- a/fgvclib/models/sotas/pmg_v2.py +++ b/fgvclib/models/sotas/pmg_v2.py @@ -1,4 +1,5 @@ import torch.nn as nn +from yacs.config import CfgNode from fgvclib.models.sotas.sota import FGVCSOTA from fgvclib.criterions import LossItem @@ -16,11 +17,12 @@ class PMG_V2(FGVCSOTA): BLOCKS = [[8, 8, 0, 0], [4, 4, 4, 0], [2, 2, 2, 2]] alpha = [0.01, 0.05, 0.1] - def __init__(self, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): - super().__init__(backbone, encoder, necks, heads, criterions) - - self.outputs_num = 3 - + def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): + super().__init__(cfg, backbone, encoder, necks, heads, criterions) + self.BLOCKS = self.args["BLOCKS"] + self.outputs_num = self.args["outputs_num"] + self.alpha = self.args["alpha"] + def forward(self, x, targets=None, step:int=None, batch_size:int=None): if step is not None: diff --git a/fgvclib/models/sotas/swin_T.py b/fgvclib/models/sotas/swin_T.py index 22f49e9..f35cff5 100644 --- a/fgvclib/models/sotas/swin_T.py +++ b/fgvclib/models/sotas/swin_T.py @@ -17,10 +17,10 @@ def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: num_classes = cfg.CLASS_NUM rand_in = torch.randn(1, 3, img_size, img_size) outs = self.backbone(rand_in) - + fpn_size = self.args["fpn_size"] ### = = = = = FPN = = = = = self.fpn = self.encoder - fpn_size = 1536 + self.build_fpn_classifier(outs, fpn_size, num_classes) ### = = = = = Selector = = = = = diff --git a/fgvclib/utils/update_function/general_update.py b/fgvclib/utils/update_function/general_update.py index 4b863d0..b6f695a 100644 --- a/fgvclib/utils/update_function/general_update.py +++ b/fgvclib/utils/update_function/general_update.py @@ -33,7 +33,7 @@ def general_update( model.train() mean_loss = 0. 
for batch_idx, train_data in enumerate(pbar): - losses_info = get_update_strategy(strategy)(model, train_data, optimizer, use_cuda, amp) + losses_info = get_update_strategy(strategy)(model=model, train_data=train_data, optimizer=optimizer, use_cuda=use_cuda, amp=amp) mean_loss = (mean_loss * batch_idx + losses_info['iter_loss']) / (batch_idx + 1) losses_info.update({"mean_loss": mean_loss}) logger(losses_info, step=batch_idx) diff --git a/fgvclib/utils/update_function/update_swint.py b/fgvclib/utils/update_function/update_swint.py index 5386b0a..f69109c 100644 --- a/fgvclib/utils/update_function/update_swint.py +++ b/fgvclib/utils/update_function/update_swint.py @@ -14,8 +14,18 @@ @update_function("update_swint") def update_swint(model: nn.Module, optimizer: Optimizer, pbar:Iterable, lr_schedule=None, strategy:str="update_swint", use_cuda:bool=True, logger:Logger=None, - epoch:int=None, total_epoch:int=None, amp:bool=False, use_selection=False, **kwargs, -): + epoch:int=None, total_epoch:int=None, amp:bool=False, use_selection=False, cfg=None, **kwargs, +): + scaler = GradScaler() + lambda_s = cfg.lambda_s + lambda_n = cfg.lambda_n + num_classes = cfg.MODEL.CLASS_NUM + use_fpn = cfg.MODEL.ARGS["use_fpn"] + lambda_b = cfg.lambda_b + use_combiner = cfg.use_combiner + lambda_c = cfg.lambda_c + update_freq = cfg.update_freq + optimizer.zero_grad() total_batchs = len(pbar) # just for log show_progress = [x / 10 for x in range(11)] # just for log @@ -56,12 +66,12 @@ def update_swint(model: nn.Module, optimizer: Optimizer, pbar:Iterable, lr_sched if "select_" in name: if not use_selection: raise ValueError("Selector not use here.") - if args.lambda_s != 0: + if lambda_s != 0: S = outs[name].size(1) - logit = outs[name].view(-1, args.num_classes).contiguous() + logit = outs[name].view(-1, num_classes).contiguous() loss_s = nn.CrossEntropyLoss()(logit, labels.unsqueeze(1).repeat(1, S).flatten(0)) - loss += args.lambda_s * loss_s + loss += lambda_s * loss_s else: loss_s = 0.0 @@ -69,40 +79,40 @@ def update_swint(model: nn.Module, optimizer: Optimizer, pbar:Iterable, lr_sched if not use_selection: raise ValueError("Selector not use here.") - if args.lambda_n != 0: + if lambda_n != 0: S = outs[name].size(1) - logit = outs[name].view(-1, args.num_classes).contiguous() + logit = outs[name].view(-1, num_classes).contiguous() n_preds = nn.Tanh()(logit) - labels_0 = torch.zeros([batch_size * S, args.num_classes]) - 1 + labels_0 = torch.zeros([batch_size * S, num_classes]) - 1 labels_0 = labels_0.cuda() loss_n = nn.MSELoss()(n_preds, labels_0) - loss += args.lambda_n * loss_n + loss += lambda_n * loss_n else: loss_n = 0.0 elif "layer" in name: - if not args.use_fpn: + if not use_fpn: raise ValueError("FPN not use here.") - if args.lambda_b != 0: + if lambda_b != 0: ### here using 'layer1'~'layer4' is default setting, you can change to your own loss_b = nn.CrossEntropyLoss()(outs[name].mean(1), labels) - loss += args.lambda_b * loss_b + loss += lambda_b * loss_b else: loss_b = 0.0 elif "comb_outs" in name: - if not args.use_combiner: + if not use_combiner: raise ValueError("Combiner not use here.") - if args.lambda_c != 0: + if lambda_c != 0: loss_c = nn.CrossEntropyLoss()(outs[name], labels) - loss += args.lambda_c * loss_c + loss += lambda_c * loss_c elif "ori_out" in name: loss_ori = F.cross_entropy(outs[name], labels) loss += loss_ori - loss /= args.update_freq + loss /= update_freq """ = = = = calculate gradient = = = = """ if amp: @@ -111,123 +121,11 @@ def update_swint(model: nn.Module, optimizer: 
Optimizer, pbar:Iterable, lr_sched loss.backward() """ = = = = update model = = = = """ - if (batch_id + 1) % args.update_freq == 0: - if args.use_amp: + if (batch_id + 1) % update_freq == 0: + if amp: scaler.step(optimizer) scaler.update() # next batch else: optimizer.step() optimizer.zero_grad() - - """ log """ - if (batch_id + 1) % args.log_freq == 0: - model.eval() - msg = {} - msg['info/epoch'] = epoch + 1 - msg['info/lr'] = get_lr(optimizer) - msg['loss'] = loss.item() - cal_train_metrics(args, msg, outs, labels, batch_size) - logger(msg) - - train_progress = (batch_id + 1) / total_batchs - # print(train_progress, show_progress[progress_i]) - if train_progress > show_progress[progress_i]: - print(".." + str(int(show_progress[progress_i] * 100)) + "%", end='', flush=True) - progress_i += 1 - - - - - - -@torch.no_grad() -def cal_train_metrics(msg: dict, outs: dict, labels: torch.Tensor, batch_size: int, use_fpn, use_selection, use_combiner): - """ - only present top-1 training accuracy - """ - - total_loss = 0.0 - - if use_fpn: - for i in range(1, 5): - acc = top_k_corrects(outs["layer"+str(i)].mean(1), labels, tops=[1])["top-1"] / batch_size - acc = round(acc * 100, 2) - msg["train_acc/layer{}_acc".format(i)] = acc - loss = F.cross_entropy(outs["layer"+str(i)].mean(1), labels) - msg["train_loss/layer{}_loss".format(i)] = loss.item() - total_loss += loss.item() - - if use_selection: - for name in outs: - if "select_" not in name: - continue - B, S, _ = outs[name].size() - logit = outs[name].view(-1, args.num_classes) - labels_0 = labels.unsqueeze(1).repeat(1, S).flatten(0) - acc = top_k_corrects(logit, labels_0, tops=[1])["top-1"] / (B*S) - acc = round(acc * 100, 2) - msg["train_acc/{}_acc".format(name)] = acc - labels_0 = torch.zeros([B * S, num_classes]) - 1 - labels_0 = labels_0.cuda() - loss = F.mse_loss(F.tanh(logit), labels_0) - msg["train_loss/{}_loss".format(name)] = loss.item() - total_loss += loss.item() - - for name in outs: - if "drop_" not in name: - continue - B, S, _ = outs[name].size() - logit = outs[name].view(-1, num_classes) - labels_1 = labels.unsqueeze(1).repeat(1, S).flatten(0) - acc = top_k_corrects(logit, labels_1, tops=[1])["top-1"] / (B*S) - acc = round(acc * 100, 2) - msg["train_acc/{}_acc".format(name)] = acc - loss = F.cross_entropy(logit, labels_1) - msg["train_loss/{}_loss".format(name)] = loss.item() - total_loss += loss.item() - - if use_combiner: - acc = top_k_corrects(outs['comb_outs'], labels, tops=[1])["top-1"] / batch_size - acc = round(acc * 100, 2) - msg["train_acc/combiner_acc"] = acc - loss = F.cross_entropy(outs['comb_outs'], labels) - msg["train_loss/combiner_loss"] = loss.item() - total_loss += loss.item() - - if "ori_out" in outs: - acc = top_k_corrects(outs["ori_out"], labels, tops=[1])["top-1"] / batch_size - acc = round(acc * 100, 2) - msg["train_acc/ori_acc"] = acc - loss = F.cross_entropy(outs["ori_out"], labels) - msg["train_loss/ori_loss"] = loss.item() - total_loss += loss.item() - - msg["train_loss/total_loss"] = total_loss - - -@torch.no_grad() -def top_k_corrects(preds: torch.Tensor, labels: torch.Tensor, tops: list = [1, 3, 5]): - """ - preds: [B, C] (C is num_classes) - labels: [B, ] - """ - if preds.device != torch.device('cpu'): - preds = preds.cpu() - if labels.device != torch.device('cpu'): - labels = labels.cpu() - tmp_cor = 0 - corrects = {"top-"+str(x):0 for x in tops} - sorted_preds = torch.sort(preds, dim=-1, descending=True)[1] - for i in range(tops[-1]): - tmp_cor += sorted_preds[:, i].eq(labels).sum().item() - # 
records - if "top-"+str(i+1) in corrects: - corrects["top-"+str(i+1)] = tmp_cor - return corrects - -def get_lr(optimizer): - for param_group in optimizer.param_groups: - if param_group["lr"] is not None: - return param_group["lr"] - - + \ No newline at end of file diff --git a/fgvclib/utils/update_strategy/general.py b/fgvclib/utils/update_strategy/general_strategy.py similarity index 100% rename from fgvclib/utils/update_strategy/general.py rename to fgvclib/utils/update_strategy/general_strategy.py diff --git a/fgvclib/utils/update_strategy/progressive_updating_with_jigsaw.py b/fgvclib/utils/update_strategy/progressive_updating_with_jigsaw.py index d7d4f42..8eb268b 100644 --- a/fgvclib/utils/update_strategy/progressive_updating_with_jigsaw.py +++ b/fgvclib/utils/update_strategy/progressive_updating_with_jigsaw.py @@ -5,7 +5,7 @@ from fgvclib.criterions import compute_loss_value, detach_loss_value -@update_strategy("update_strategy") +@update_strategy("progressive_updating_with_jigsaw") def progressive_updating_with_jigsaw(model, train_data, optimizer, use_cuda=True, **kwargs): inputs, targets = train_data if use_cuda: From 7497205ae2792aa11ec53d8419cf02f7745dd8ac Mon Sep 17 00:00:00 2001 From: xin-ran-w Date: Wed, 18 Jan 2023 16:41:09 +0800 Subject: [PATCH 04/10] update swintransformer backbone --- configs/swin_transformer/swin_transformer.yml | 42 +- fgvclib/models/backbones/swin_t.py | 7 - fgvclib/models/backbones/swin_transformer.py | 705 ++++++++++++++++++ fgvclib/models/heads/gcn_combiner.py | 2 +- fgvclib/models/sotas/swin_T.py | 109 ++- fgvclib/transforms/base_transforms.py | 8 + 6 files changed, 835 insertions(+), 38 deletions(-) create mode 100644 fgvclib/models/backbones/swin_transformer.py diff --git a/configs/swin_transformer/swin_transformer.yml b/configs/swin_transformer/swin_transformer.yml index 0b3c4b1..259d521 100644 --- a/configs/swin_transformer/swin_transformer.yml +++ b/configs/swin_transformer/swin_transformer.yml @@ -31,6 +31,19 @@ MODEL: ARGS: - img_size: 384 - fpn_size: 1536 + - lambda_s: 0.0 + - lambda_n: 5.0 + - lambda_b: 0.5 + - lambda_c: 1.0 + - update_freq: 2 + - use_selection: True + - use_fpn: True + - use_combiner: True + - num_select: + - layer1: 2048 + - layer2: 512 + - layer3: 128 + - layer4: 32 CRITERIONS: - name: "cross_entropy_loss" @@ -40,25 +53,13 @@ MODEL: args: [] w: 1.0 BACKBONE: - NAME: "swint" + NAME: "swin_large_patch4_window12_384_in22k" ARGS: - pretrained: True ENCODER: - NAME: "fpn" - ARGS: - - fpn_size: 1536 - - proj_type: "Linear" - - upsample_type: "Conv" + NAME: ~ NECKS: - NAME: "weakly_selector" - ARGS: - - num_classes: 200 - - num_selects: - - layer1: 2048 - - layer2: 512 - - layer3: 128 - - layer4: 32 - - fpn_size: 1536 + NAME: ~ HEADS: NAME: "GCN_combiner" ARGS: @@ -125,12 +126,17 @@ OPTIMIZER: heads: 0.005 WEIGHT_DECAY: 0.0005 +LR_SCHEDULE: + NAME: "warmup_linear_schedule" + ARGS: + - warmup_steps: 800 + - total_steps: 10000 + - MAX_LR: 0.005 + ITERATION_NUM: ~ -WARMUP_BATCHS: 800 -MAX_LR: 0.005 EPOCH_NUM: 50 START_EPOCH: 0 -USE_AMP: True +AMP: True UPDATE_STRATEGY: ~ diff --git a/fgvclib/models/backbones/swin_t.py b/fgvclib/models/backbones/swin_t.py index 0151722..e684f5c 100644 --- a/fgvclib/models/backbones/swin_t.py +++ b/fgvclib/models/backbones/swin_t.py @@ -27,10 +27,3 @@ def swint(cfg): backbone = timm.create_model('swin_large_patch4_window12_384_in22k', pretrained=cfg['pretrained']) backbone.train() return backbone - - -def vit16(cfg): - backbone = timm.create_model('vit_base_patch16_224_miil_in21k', pretrained=False) 
- backbone = load_model_weights(backbone, cfg['pretrained']) - backbone.train() - return backbone diff --git a/fgvclib/models/backbones/swin_transformer.py b/fgvclib/models/backbones/swin_transformer.py new file mode 100644 index 0000000..6f5ce1f --- /dev/null +++ b/fgvclib/models/backbones/swin_transformer.py @@ -0,0 +1,705 @@ +""" Swin Transformer +A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` + - https://arxiv.org/pdf/2103.14030 + +Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below + +""" +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- +import logging +import math +from copy import deepcopy +from typing import Optional + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.helpers import build_model_with_cfg +from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, lecun_normal_ +from timm.models.vision_transformer import checkpoint_filter_fn + +from . import backbone + +_logger = logging.getLogger(__name__) + + +def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): + """ ViT weight initialization + * When called without n, head_bias, jax_impl args it will behave exactly the same + as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). + * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl + """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + elif name.startswith('pre_logits'): + lecun_normal_(module.weight) + nn.init.zeros_(module.bias) + else: + if jax_impl: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + else: + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif jax_impl and isinstance(module, nn.Conv2d): + # NOTE conv was left to pytorch default in my original init + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): + nn.init.zeros_(module.bias) + nn.init.ones_(module.weight) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (my experiments) + 'swin_base_patch4_window12_384': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_base_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth', + ), + + 'swin_large_patch4_window12_384': _cfg( + 
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_large_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth', + ), + + 'swin_small_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth', + ), + + 'swin_tiny_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth', + ), + + 'swin_base_patch4_window12_384_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth', + input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), + + 'swin_base_patch4_window7_224_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth', + num_classes=21841), + + 'swin_large_patch4_window12_384_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth', + input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), + + 'swin_large_patch4_window7_224_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth', + num_classes=21841), + +} + + +def window_partition(x, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. 
Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, + attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + # H *= 2 + # W *= 2 + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. 
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + # H *= 2 + # W *= 2 + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
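+
+        Example (illustrative sketch, not part of the upstream timm docstring;
+        shapes assume the Swin-T stage-1 defaults)::
+
+            layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2,
+                               num_heads=3, window_size=7)
+            out = layer(torch.randn(2, 56 * 56, 96))
+            # -> (2, 56*56, 96); with downsample=PatchMerging the tokens are
+            #    merged to (2, 28*28, 192)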
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if not torch.jit.is_scripting() and self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), + window_size=7, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, weight_init='', **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + self.patch_grid = self.patch_embed.grid_size + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + else: + self.absolute_pos_embed = None + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + layers = [] + for i_layer in range(self.num_layers): + layers += [BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + input_resolution=(self.patch_grid[0] // (2 ** i_layer), self.patch_grid[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + ] + self.layers = nn.Sequential(*layers) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0. + if weight_init.startswith('jax'): + for n, m in self.named_modules(): + _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True) + else: + self.apply(_init_vit_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + if self.absolute_pos_embed is not None: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + l1 = self.layers[0](x) + l2 = self.layers[1](l1) + l3 = self.layers[2](l2) + l4 = self.layers[3](l3) + # x = self.norm(l4) # B L C + # x = self.avgpool(x.transpose(1, 2)) # B C 1 + # x = torch.flatten(x, 1) + return {"layer1":l1, "layer2":l2, "layer3":l3, "layer4":l4} + + def forward(self, x): + x = self.forward_features(x) + # x = self.head(x) + return x + +def overlay_external_default_cfg(default_cfg, kwargs): + """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg. 
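+
+    Example (illustrative; the dict values are made up for the sketch)::
+
+        default_cfg = {'url': 'http://a/ckpt.pth', 'num_classes': 1000}
+        kwargs = {'external_default_cfg': {'url': 'http://b/ckpt.pth'}, 'img_size': 384}
+        overlay_external_default_cfg(default_cfg, kwargs)
+        # default_cfg -> {'num_classes': 1000, 'url': 'http://b/ckpt.pth'}
+        # 'external_default_cfg' is popped from kwargs; other kwargs are untouched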
+ """ + external_default_cfg = kwargs.pop('external_default_cfg', None) + if external_default_cfg: + default_cfg.pop('url', None) # url should come from external cfg + default_cfg.pop('hf_hub', None) # hf hub id should come from external cfg + default_cfg.update(external_default_cfg) + + +def _create_swin_transformer(variant, cfg, default_cfg=None, **kwargs): + pretrained = False if "pretrained" not in cfg.keys() else cfg['pretrained'] + if default_cfg is None: + default_cfg = deepcopy(default_cfgs[variant]) + overlay_external_default_cfg(default_cfg, kwargs) + default_num_classes = default_cfg['num_classes'] + default_img_size = default_cfg['input_size'][-2:] + + num_classes = kwargs.pop('num_classes', default_num_classes) + img_size = kwargs.pop('img_size', default_img_size) + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + SwinTransformer, variant, pretrained, + default_cfg=default_cfg, + img_size=img_size, + num_classes=num_classes, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + + return model + + +@backbone('swin_base_patch4_window12_384') +def swin_base_patch4_window12_384(cfg, **kwargs): + """ Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k + """ + model_kwargs = dict( + patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer('swin_base_patch4_window12_384', cfg, **model_kwargs) + + +@backbone('swin_base_patch4_window7_224') +def swin_base_patch4_window7_224(cfg, **kwargs): + """ Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer('swin_base_patch4_window7_224', cfg, **model_kwargs) + + +@backbone('swin_large_patch4_window12_384') +def swin_large_patch4_window12_384(cfg, **kwargs): + """ Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k + """ + model_kwargs = dict( + patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer('swin_large_patch4_window12_384', cfg, **model_kwargs) + + +@backbone('swin_large_patch4_window7_224') +def swin_large_patch4_window7_224(cfg, **kwargs): + """ Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer('swin_large_patch4_window7_224', cfg, **model_kwargs) + + +@backbone('swin_small_patch4_window7_224') +def swin_small_patch4_window7_224(cfg, **kwargs): + """ Swin-S @ 224x224, trained ImageNet-1k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer('swin_small_patch4_window7_224', cfg, **model_kwargs) + + +@backbone('swin_tiny_patch4_window7_224') +def swin_tiny_patch4_window7_224(cfg, **kwargs): + """ Swin-T @ 224x224, trained ImageNet-1k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer('swin_tiny_patch4_window7_224', cfg, **model_kwargs) + + +@backbone('swin_base_patch4_window12_384_in22k') +def swin_base_patch4_window12_384_in22k(cfg, **kwargs): + """ Swin-B @ 384x384, trained ImageNet-22k + """ + 
model_kwargs = dict( + patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer('swin_base_patch4_window12_384_in22k', cfg, **model_kwargs) + + +@backbone('swin_base_patch4_window7_224_in22k') +def swin_base_patch4_window7_224_in22k(cfg, **kwargs): + """ Swin-B @ 224x224, trained ImageNet-22k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer('swin_base_patch4_window7_224_in22k', cfg, **model_kwargs) + + +@backbone('swin_large_patch4_window12_384_in22k') +def swin_large_patch4_window12_384_in22k(cfg, **kwargs): + """ Swin-L @ 384x384, trained ImageNet-22k + """ + model_kwargs = dict( + patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer('swin_large_patch4_window12_384_in22k', cfg, **model_kwargs) + + +@backbone('swin_large_patch4_window7_224_in22k') +def swin_large_patch4_window7_224_in22k(cfg, **kwargs): + """ Swin-L @ 224x224, trained ImageNet-22k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer('swin_large_patch4_window7_224_in22k', cfg, **model_kwargs) \ No newline at end of file diff --git a/fgvclib/models/heads/gcn_combiner.py b/fgvclib/models/heads/gcn_combiner.py index 580f278..aa10a6b 100644 --- a/fgvclib/models/heads/gcn_combiner.py +++ b/fgvclib/models/heads/gcn_combiner.py @@ -97,6 +97,6 @@ def forward(self, x): @head("GCN_combiner") -def GCN_combiner(cfg: dict): +def GCN_combiner(cfg: dict, **kwargs): return GCNCombiner(total_num_selects=cfg['total_num_selects'], num_classes=cfg['num_classes'], fpn_size=cfg['fpn_size']) diff --git a/fgvclib/models/sotas/swin_T.py b/fgvclib/models/sotas/swin_T.py index f35cff5..281e571 100644 --- a/fgvclib/models/sotas/swin_T.py +++ b/fgvclib/models/sotas/swin_T.py @@ -5,7 +5,9 @@ from .sota import FGVCSOTA from fgvclib.models.sotas import fgvcmodel - +from fgvclib.criterions.utils import LossItem +from fgvclib.models.encoders.fpn import FPN +from fgvclib.models.necks.weakly_selector import WeaklySelector @fgvcmodel("Swin_T") class Swin_T(FGVCSOTA): @@ -13,15 +15,44 @@ def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: super().__init__(cfg, backbone, encoder, necks, heads, criterions) ### get hidden feartues size - img_size = self.args["img_size"] - num_classes = cfg.CLASS_NUM - rand_in = torch.randn(1, 3, img_size, img_size) - outs = self.backbone(rand_in) - fpn_size = self.args["fpn_size"] + + + self.num_classes = cfg.CLASS_NUM + self.use_fpn = self.args["use_fpn"] + self.lambda_s = self.args["lambda_s"] + self.lambda_n = self.args["lambda_n"] + self.lambda_b = self.args["lambda_b"] + self.lambda_c = self.args["lambda_c"] + self.use_combiner = self.args["use_combiner"] + self.update_freq = self.args["update_freq"] + num_select = self.args["num_select"] + num_select = { + "layer1": 2048, + "layer2": 512, + "layer3": 128, + "layer4": 32 + } + + input_size = self.args["img_size"] + rand_in = torch.randn(1, 3, input_size, input_size) + backbone_outs = self.backbone(rand_in) + + if self.use_fpn: + + fpn_size = self.args["fpn_size"] + self.encoder = FPN(inputs=backbone_outs, fpn_size=fpn_size, proj_type="Linear", upsample_type="Conv") + fpn_outs = self.encoder(backbone_outs) + else: + fpn_outs = backbone_outs + fpn_size = 
None + + + self.necks = WeaklySelector(inputs=fpn_outs, num_classes=self.num_classes, num_select=num_select, fpn_size=fpn_size) + ### = = = = = FPN = = = = = self.fpn = self.encoder - self.build_fpn_classifier(outs, fpn_size, num_classes) + self.build_fpn_classifier(backbone_outs, fpn_size, self.num_classes) ### = = = = = Selector = = = = = self.selector = self.necks @@ -31,15 +62,15 @@ def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: ### just original backbone if not self.fpn and (not self.combiner): - for name in outs: - fs_size = outs[name].size() + for name in backbone_outs: + fs_size = backbone_outs[name].size() if len(fs_size) == 3: out_size = fs_size.size(-1) elif len(fs_size) == 4: out_size = fs_size.size(1) else: raise ValueError("The size of output dimension of previous must be 3 or 4.") - self.classifier = nn.Linear(out_size, num_classes) + self.classifier = nn.Linear(out_size, self.num_classes) def build_fpn_classifier(self, inputs: dict, fpn_size: int, num_classes: int): for name in inputs: @@ -69,7 +100,7 @@ def fpn_predict(self, x: dict, logits: dict): logits[name] = getattr(self, "fpn_classifier_" + name)(logit) logits[name] = logits[name].transpose(1, 2).contiguous() # transpose - def forward(self, x: torch.Tensor): + def infer(self, x: torch.Tensor): logits = {} x = self.forward_backbone(x) @@ -86,7 +117,7 @@ def forward(self, x: torch.Tensor): logits['comb_outs'] = comb_outs return logits - if self.selector or self.fpn: + if self.selector or self.fpn: return logits ### original backbone (only predict final selected layer) @@ -102,3 +133,57 @@ def forward(self, x: torch.Tensor): logits['ori_out'] = logits return logits + + def forward(self, x, target): + batch_size = x.shape[0] + device = x.device + logits = self.infer(x) + losses = list() + + for name in logits: + + if "select_" in name: + if not self.use_selection: + raise ValueError("Selector not use here.") + if self.lambda_s != 0: + S = logits[name].size(1) + logit = logits[name].view(-1, self.num_classes).contiguous() + loss_s = nn.CrossEntropyLoss()(logit, target.unsqueeze(1).repeat(1, S).flatten(0)) + losses.append(LossItem(name="loss_s", value=loss_s, weight=self.lambda_s)) + + + + elif "drop_" in name: + if not self.use_selection: + raise ValueError("Selector not use here.") + + if self.lambda_n != 0: + S = logits[name].size(1) + logit = logits[name].view(-1, self.num_classes).contiguous() + n_preds = nn.Tanh()(logit) + labels_0 = (torch.zeros([batch_size * S, self.num_classes]) - 1).to(device) + loss_n = nn.MSELoss()(n_preds, labels_0) + losses.append(LossItem(name="loss_n", value=loss_n, weight=self.lambda_n)) + + + elif "layer" in name: + if not self.use_fpn: + raise ValueError("FPN not use here.") + if self.lambda_b != 0: + ### here using 'layer1'~'layer4' is default setting, you can change to your own + loss_b = nn.CrossEntropyLoss()(logits[name].mean(1), target) + losses.append(LossItem(name="loss_b", value=loss_b, weight=self.lambda_b)) + + elif "comb_outs" in name: + if not self.use_combiner: + raise ValueError("Combiner not use here.") + + if self.lambda_c != 0: + loss_c = nn.CrossEntropyLoss()(logits[name], target) + losses.append(LossItem(name="loss_c", value=loss_c, weight=self.lambda_c)) + + elif "ori_out" in name: + loss_ori = F.cross_entropy(logits[name], target) + losses.append(LossItem(name="loss_ori", value=loss_ori, weight=1.0)) + + return diff --git a/fgvclib/transforms/base_transforms.py b/fgvclib/transforms/base_transforms.py index 06317bc..6981617 100644 --- 
a/fgvclib/transforms/base_transforms.py +++ b/fgvclib/transforms/base_transforms.py @@ -30,3 +30,11 @@ def normalize(cfg: dict): @transform("color_jitter") def color_jitter(cfg: dict): return transforms.ColorJitter(brightness=cfg['brightness'], saturation=cfg['saturation']) + +@transform('randomApply_gaussianBlur') +def randomApply_gaussianBlur(cfg: dict): + return transforms.RandomApply([transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 5))], p=cfg['prob']) + +@transform('randomAdjust_sharpness') +def randomAdjust_sharpness(cfg:dict): + return transforms.RandomAdjustSharpness(sharpness_factor=cfg['sharpness_factor'], p=cfg['prob']) From 4b41841144ffd677334d4a05ba8254834afcd086 Mon Sep 17 00:00:00 2001 From: xin-ran-w Date: Thu, 19 Jan 2023 22:43:04 +0800 Subject: [PATCH 05/10] update swin --- configs/swin_transformer/swin_transformer.yml | 15 +- fgvclib/apis/__init__.py | 5 +- fgvclib/apis/build.py | 23 ++- fgvclib/apis/evaluate_model.py | 49 ------- fgvclib/configs/config.py | 1 + fgvclib/models/backbones/swin_transformer.py | 3 +- .../sotas/{swin_T.py => swin_transformer.py} | 118 ++++++++-------- fgvclib/utils/evaluate_function/__init__.py | 32 +++++ .../utils/evaluate_function/evaluate_model.py | 90 ++++++++++++ fgvclib/utils/lr_schedules/__init__.py | 6 +- .../utils/lr_schedules/adjusting_schedule.py | 2 +- .../lr_schedules/cosine_anneal_schedule.py | 4 +- .../lr_schedules/cosine_decay_schedule.py | 46 ++++-- .../lr_schedules/warmup_cosine_schedule.py | 2 +- fgvclib/utils/update_function/__init__.py | 10 +- .../update_swin_transformer.py | 69 +++++++++ fgvclib/utils/update_function/update_swint.py | 131 ------------------ fgvclib/utils/update_strategy/__init__.py | 10 +- main.py | 10 +- 19 files changed, 333 insertions(+), 293 deletions(-) delete mode 100644 fgvclib/apis/evaluate_model.py rename fgvclib/models/sotas/{swin_T.py => swin_transformer.py} (61%) create mode 100644 fgvclib/utils/evaluate_function/__init__.py create mode 100644 fgvclib/utils/evaluate_function/evaluate_model.py create mode 100644 fgvclib/utils/update_function/update_swin_transformer.py delete mode 100644 fgvclib/utils/update_function/update_swint.py diff --git a/configs/swin_transformer/swin_transformer.yml b/configs/swin_transformer/swin_transformer.yml index 259d521..fbb7f00 100644 --- a/configs/swin_transformer/swin_transformer.yml +++ b/configs/swin_transformer/swin_transformer.yml @@ -26,7 +26,7 @@ DATASET: NUM_WORKERS: 1 MODEL: - NAME: "Swin_T" + NAME: "SwinTransformer" CLASS_NUM: 200 ARGS: - img_size: 384 @@ -127,19 +127,20 @@ OPTIMIZER: WEIGHT_DECAY: 0.0005 LR_SCHEDULE: - NAME: "warmup_linear_schedule" + NAME: "warmup_cosine_decay_schedule" ARGS: - warmup_steps: 800 - - total_steps: 10000 - - MAX_LR: 0.005 + - max_lr: 0.005 + - max_epochs: 50 + - decay_type: 1 ITERATION_NUM: ~ EPOCH_NUM: 50 START_EPOCH: 0 AMP: True -UPDATE_STRATEGY: ~ - - +UPDATE_STRATEGY: "" +UPDATE_FUNCTION: "update_swin_transformer" +EVALUATE_FUNCTION: "swin_transformer_evaluate" # Validation details PER_ITERATION: ~ diff --git a/fgvclib/apis/__init__.py b/fgvclib/apis/__init__.py index d716eee..962b610 100644 --- a/fgvclib/apis/__init__.py +++ b/fgvclib/apis/__init__.py @@ -1,4 +1,3 @@ from .build import * -from .evaluate_model import evaluate_model -from .save_model import save_model -from .seed import set_seed +from .save_model import * +from .seed import * diff --git a/fgvclib/apis/build.py b/fgvclib/apis/build.py index b0ffb6c..9b192d3 100644 --- a/fgvclib/apis/build.py +++ b/fgvclib/apis/build.py @@ -31,6 +31,7 @@ from 
fgvclib.utils.interpreter import get_interpreter, Interpreter
 from fgvclib.utils.lr_schedules import get_lr_schedule, LRSchedule
 from fgvclib.utils.update_function import get_update_function
+from fgvclib.utils.evaluate_function import get_evaluate_function
 
 
 def build_model(model_cfg: CfgNode) -> FGVCSOTA:
@@ -221,7 +222,7 @@ def build_sampler(sampler_cfg: CfgNode) -> Sampler:
     return get_sampler(sampler_cfg.NAME)
 
 
-def build_lr_schedule(optimizer, schedule_cfg: CfgNode) -> LRSchedule:
+def build_lr_schedule(optimizer, schedule_cfg: CfgNode, train_loader) -> LRSchedule:
     r"""Build lr_schedule for training.
 
     Args:
@@ -230,18 +231,30 @@
         LRSchedule: A lr schedule.
 
     """
+    batch_num_per_epoch = len(train_loader)
+    return get_lr_schedule(schedule_cfg.NAME)(optimizer, batch_num_per_epoch, tltd(schedule_cfg.ARGS))
 
-    return get_lr_schedule(schedule_cfg.NAME)(optimizer, tltd(schedule_cfg.ARGS))
-
 
-def build_update_function(cfg) -> t.Callable:
+def build_update_function(cfg):
     r"""Build the update function for training.
 
     Args:
         cfg (CfgNode): The root config node.
     Returns:
-        Callable: A update_function.
+        function: An update function.
 
     """
 
     return get_update_function(cfg.UPDATE_FUNCTION)
+
+
+def build_evaluate_function(cfg):
+    r"""Build the evaluate function for testing.
+
+    Args:
+        cfg (CfgNode): The root config node.
+    Returns:
+        function: An evaluate function.
+
+    """
+    return get_evaluate_function(cfg.EVALUATE_FUNCTION)
diff --git a/fgvclib/apis/evaluate_model.py b/fgvclib/apis/evaluate_model.py
deleted file mode 100644
index 1eb8552..0000000
--- a/fgvclib/apis/evaluate_model.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright (c) 2022-present, BUPT-PRIS.
-
-"""
-    This file provides a api for evaluating FGVC algorithms.
-"""
-
-
-import torch
-import torch.nn as nn
-from torch.autograd import Variable
-import typing as t
-from fgvclib.metrics.metrics import NamedMetric
-
-
-def evaluate_model(model:nn.Module, p_bar:t.Iterable, metrics:t.List[NamedMetric], use_cuda:bool=True) -> t.Dict:
-    r"""Evaluate the FGVC model.
-
-    Args:
-        model (nn.Module):
-            The FGVC model.
-        p_bar (iterable):
-            A iterator provide test data.
-        metrics (List[NamedMetric]):
-            List of metrics.
-        use_cuda (boolean, optional):
-            Whether to use gpu.
-
-    Returns:
-        dict: The result dict.
- """ - - model.eval() - results = dict() - - with torch.no_grad(): - for _, (inputs, targets) in enumerate(p_bar): - if use_cuda: - inputs, targets = inputs.cuda(), targets.cuda() - inputs, targets = Variable(inputs), Variable(targets) - for metric in metrics: - _ = metric.update(model(inputs), targets) - - for metric in metrics: - result = metric.compute() - results.update({ - metric.name: round(result.item(), 3) - }) - - return results diff --git a/fgvclib/configs/config.py b/fgvclib/configs/config.py index 47d8824..c79b5aa 100644 --- a/fgvclib/configs/config.py +++ b/fgvclib/configs/config.py @@ -135,6 +135,7 @@ def __init__(self): self.cfg.PER_ITERATION = None self.cfg.PER_EPOCH = None self.cfg.METRICS = None + self.cfg.EVALUATE_FUNCTION = "general_evaluate" # Inference self.cfg.FIFTYONE = CN() diff --git a/fgvclib/models/backbones/swin_transformer.py b/fgvclib/models/backbones/swin_transformer.py index 6f5ce1f..9e28040 100644 --- a/fgvclib/models/backbones/swin_transformer.py +++ b/fgvclib/models/backbones/swin_transformer.py @@ -702,4 +702,5 @@ def swin_large_patch4_window7_224_in22k(cfg, **kwargs): """ model_kwargs = dict( patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) - return _create_swin_transformer('swin_large_patch4_window7_224_in22k', cfg, **model_kwargs) \ No newline at end of file + return _create_swin_transformer('swin_large_patch4_window7_224_in22k', cfg, **model_kwargs) + \ No newline at end of file diff --git a/fgvclib/models/sotas/swin_T.py b/fgvclib/models/sotas/swin_transformer.py similarity index 61% rename from fgvclib/models/sotas/swin_T.py rename to fgvclib/models/sotas/swin_transformer.py index 281e571..53feaa3 100644 --- a/fgvclib/models/sotas/swin_T.py +++ b/fgvclib/models/sotas/swin_transformer.py @@ -9,8 +9,8 @@ from fgvclib.models.encoders.fpn import FPN from fgvclib.models.necks.weakly_selector import WeaklySelector -@fgvcmodel("Swin_T") -class Swin_T(FGVCSOTA): +@fgvcmodel("SwinTransformer") +class SwinTransformer(FGVCSOTA): def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: nn.Module, heads: nn.Module, criterions: nn.Module): super().__init__(cfg, backbone, encoder, necks, heads, criterions) @@ -25,13 +25,8 @@ def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks: self.lambda_c = self.args["lambda_c"] self.use_combiner = self.args["use_combiner"] self.update_freq = self.args["update_freq"] + self.use_selection = self.args["use_selection"] num_select = self.args["num_select"] - num_select = { - "layer1": 2048, - "layer2": 512, - "layer3": 128, - "layer4": 32 - } input_size = self.args["img_size"] rand_in = torch.randn(1, 3, input_size, input_size) @@ -134,56 +129,59 @@ def infer(self, x: torch.Tensor): return logits - def forward(self, x, target): - batch_size = x.shape[0] - device = x.device + def forward(self, x, target=None): + logits = self.infer(x) - losses = list() - - for name in logits: - - if "select_" in name: - if not self.use_selection: - raise ValueError("Selector not use here.") - if self.lambda_s != 0: - S = logits[name].size(1) - logit = logits[name].view(-1, self.num_classes).contiguous() - loss_s = nn.CrossEntropyLoss()(logit, target.unsqueeze(1).repeat(1, S).flatten(0)) - losses.append(LossItem(name="loss_s", value=loss_s, weight=self.lambda_s)) - - - - elif "drop_" in name: - if not self.use_selection: - raise ValueError("Selector not use here.") - - if self.lambda_n != 0: - S = logits[name].size(1) - logit = 
logits[name].view(-1, self.num_classes).contiguous() - n_preds = nn.Tanh()(logit) - labels_0 = (torch.zeros([batch_size * S, self.num_classes]) - 1).to(device) - loss_n = nn.MSELoss()(n_preds, labels_0) - losses.append(LossItem(name="loss_n", value=loss_n, weight=self.lambda_n)) - - - elif "layer" in name: - if not self.use_fpn: - raise ValueError("FPN not use here.") - if self.lambda_b != 0: - ### here using 'layer1'~'layer4' is default setting, you can change to your own - loss_b = nn.CrossEntropyLoss()(logits[name].mean(1), target) - losses.append(LossItem(name="loss_b", value=loss_b, weight=self.lambda_b)) - - elif "comb_outs" in name: - if not self.use_combiner: - raise ValueError("Combiner not use here.") - - if self.lambda_c != 0: - loss_c = nn.CrossEntropyLoss()(logits[name], target) - losses.append(LossItem(name="loss_c", value=loss_c, weight=self.lambda_c)) - - elif "ori_out" in name: - loss_ori = F.cross_entropy(logits[name], target) - losses.append(LossItem(name="loss_ori", value=loss_ori, weight=1.0)) - - return + + if not self.training: + return logits + else: + + losses = list() + batch_size = x.shape[0] + device = x.device + for name in logits: + + if "select_" in name: + if not self.use_selection: + raise ValueError("Selector not use here.") + if self.lambda_s != 0: + S = logits[name].size(1) + logit = logits[name].view(-1, self.num_classes).contiguous() + loss_s = nn.CrossEntropyLoss()(logit, target.unsqueeze(1).repeat(1, S).flatten(0)) + losses.append(LossItem(name="loss_s", value=loss_s, weight=self.lambda_s)) + + elif "drop_" in name: + if not self.use_selection: + raise ValueError("Selector not use here.") + + if self.lambda_n != 0: + S = logits[name].size(1) + logit = logits[name].view(-1, self.num_classes).contiguous() + n_preds = nn.Tanh()(logit) + labels_0 = (torch.zeros([batch_size * S, self.num_classes]) - 1).to(device) + loss_n = nn.MSELoss()(n_preds, labels_0) + losses.append(LossItem(name="loss_n", value=loss_n, weight=self.lambda_n)) + + + elif "layer" in name: + if not self.use_fpn: + raise ValueError("FPN not use here.") + if self.lambda_b != 0: + ### here using 'layer1'~'layer4' is default setting, you can change to your own + loss_b = nn.CrossEntropyLoss()(logits[name].mean(1), target) + losses.append(LossItem(name="loss_b", value=loss_b, weight=self.lambda_b)) + + elif "comb_outs" in name: + if not self.use_combiner: + raise ValueError("Combiner not use here.") + + if self.lambda_c != 0: + loss_c = nn.CrossEntropyLoss()(logits[name], target) + losses.append(LossItem(name="loss_c", value=loss_c, weight=self.lambda_c)) + + elif "ori_out" in name: + loss_ori = F.cross_entropy(logits[name], target) + losses.append(LossItem(name="loss_ori", value=loss_ori, weight=1.0)) + + return logits, losses diff --git a/fgvclib/utils/evaluate_function/__init__.py b/fgvclib/utils/evaluate_function/__init__.py new file mode 100644 index 0000000..c110e79 --- /dev/null +++ b/fgvclib/utils/evaluate_function/__init__.py @@ -0,0 +1,32 @@ +import os +import importlib + +__EVAL_FN_DICT__ = {} + + +def get_evaluate_function(name): + r"""Return the evaluate function with the given name. + Args: + name (str): + The name of evaluate function. + Return: + (function): The evaluate function. + """ + + return __EVAL_FN_DICT__[name] + +def evaluate_function(name): + + def register_function_fn(cls): + if name in __EVAL_FN_DICT__: + raise ValueError("Name %s already registered!" 
% name)
+        __EVAL_FN_DICT__[name] = cls
+        return cls
+
+    return register_function_fn
+
+for file in os.listdir(os.path.dirname(__file__)):
+    if file.endswith('.py') and not file.startswith('_'):
+        module_name = file[:file.find('.py')]
+        module = importlib.import_module('fgvclib.utils.evaluate_function.' + module_name)
+    
\ No newline at end of file
diff --git a/fgvclib/utils/evaluate_function/evaluate_model.py b/fgvclib/utils/evaluate_function/evaluate_model.py
new file mode 100644
index 0000000..8bb1a78
--- /dev/null
+++ b/fgvclib/utils/evaluate_function/evaluate_model.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2022-present, BUPT-PRIS.
+
+"""
+    This file provides an API for evaluating FGVC algorithms.
+"""
+
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import typing as t
+
+from fgvclib.metrics.metrics import NamedMetric
+from . import evaluate_function
+
+@evaluate_function("general_evaluate")
+def general_evaluate(model:nn.Module, p_bar:t.Iterable, metrics:t.List[NamedMetric], use_cuda:bool=True) -> t.Dict:
+    r"""Evaluate the FGVC model.
+
+    Args:
+        model (nn.Module):
+            The FGVC model.
+        p_bar (iterable):
+            An iterator providing test data.
+        metrics (List[NamedMetric]):
+            List of metrics.
+        use_cuda (boolean, optional):
+            Whether to use gpu.
+
+    Returns:
+        dict: The result dict.
+    """
+
+    model.eval()
+    results = dict()
+
+    with torch.no_grad():
+        for _, (inputs, targets) in enumerate(p_bar):
+            if use_cuda:
+                inputs, targets = inputs.cuda(), targets.cuda()
+            inputs, targets = Variable(inputs), Variable(targets)
+            logits = model(inputs)
+            for metric in metrics:
+                _ = metric.update(logits, targets)
+
+    for metric in metrics:
+        result = metric.compute()
+        results.update({
+            metric.name: round(result.item(), 3)
+        })
+
+    return results
+
+@evaluate_function("swin_transformer_evaluate")
+def swin_transformer_evaluate(model:nn.Module, p_bar:t.Iterable, metrics:t.List[NamedMetric], use_cuda:bool=True) -> t.Dict:
+    r"""Evaluate the Swin Transformer FGVC model on its combiner outputs.
+
+    Args:
+        model (nn.Module):
+            The FGVC model.
+        p_bar (iterable):
+            An iterator providing test data.
+        metrics (List[NamedMetric]):
+            List of metrics.
+        use_cuda (boolean, optional):
+            Whether to use gpu.
+
+    Returns:
+        dict: The result dict.
+    """
+
+    model.eval()
+    results = dict()
+
+    with torch.no_grad():
+        for _, (inputs, targets) in enumerate(p_bar):
+            if use_cuda:
+                inputs, targets = inputs.cuda(), targets.cuda()
+            inputs, targets = Variable(inputs), Variable(targets)
+            logits = model(inputs)
+            for metric in metrics:
+                _ = metric.update(logits["comb_outs"], targets)
+
+    for metric in metrics:
+        result = metric.compute()
+        results.update({
+            metric.name: round(result.item(), 3)
+        })
+
+    return results
diff --git a/fgvclib/utils/lr_schedules/__init__.py b/fgvclib/utils/lr_schedules/__init__.py
index 9720003..cc8eb85 100644
--- a/fgvclib/utils/lr_schedules/__init__.py
+++ b/fgvclib/utils/lr_schedules/__init__.py
@@ -9,12 +9,12 @@
 
 def get_lr_schedule(name) -> LRSchedule:
-    r"""Return the dataset with the given name.
+    r"""Return the learning rate schedule with the given name.
     Args:
         name (str):
-            The name of dataset.
+            The name of the learning rate schedule.
     Return:
-        (FGVCDataset): The dataset contructor method.
+        (LRSchedule): The learning rate schedule constructor.
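+
+    Example (illustrative; the args mirror the LR_SCHEDULE block of
+    configs/swin_transformer/swin_transformer.yml)::
+
+        schedule = get_lr_schedule("warmup_cosine_decay_schedule")(
+            optimizer, batch_num_per_epoch,
+            {"warmup_steps": 800, "max_lr": 0.0001, "max_epochs": 50, "decay_type": 1})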
""" return __LR_SCHEDULE_DICT__[name] diff --git a/fgvclib/utils/lr_schedules/adjusting_schedule.py b/fgvclib/utils/lr_schedules/adjusting_schedule.py index f471c11..37ba952 100644 --- a/fgvclib/utils/lr_schedules/adjusting_schedule.py +++ b/fgvclib/utils/lr_schedules/adjusting_schedule.py @@ -20,5 +20,5 @@ def step(self, optimizer, batch_idx, current_epoch, total_batch, **kwargs): @lr_schedule("adjusting_schedule") -def adjusting_schedule(optimizer, cfg:dict): +def adjusting_schedule(optimizer, batch_num_per_epoch, cfg:dict): return Adjusting_Schedule(optimizer, **cfg) diff --git a/fgvclib/utils/lr_schedules/cosine_anneal_schedule.py b/fgvclib/utils/lr_schedules/cosine_anneal_schedule.py index 6ca37cd..db4af36 100644 --- a/fgvclib/utils/lr_schedules/cosine_anneal_schedule.py +++ b/fgvclib/utils/lr_schedules/cosine_anneal_schedule.py @@ -6,7 +6,7 @@ class CosineAnnealSchedule(LRSchedule): - def __init__(self, optimizer) -> None: + def __init__(self, optimizer, **kwargs) -> None: super().__init__(optimizer) self.update_level = 'epoch' @@ -20,5 +20,5 @@ def step(self, current_epoch, total_epoch, **kwargs): pg['lr'] = float(current_lr / 2 * cos_out) @lr_schedule("cosine_anneal_schedule") -def cosine_anneal_schedule(optimizer, cfg:dict): +def cosine_anneal_schedule(optimizer, batch_num_per_epoch, cfg:dict): return CosineAnnealSchedule(optimizer, **cfg) diff --git a/fgvclib/utils/lr_schedules/cosine_decay_schedule.py b/fgvclib/utils/lr_schedules/cosine_decay_schedule.py index ef77d78..ff6687c 100644 --- a/fgvclib/utils/lr_schedules/cosine_decay_schedule.py +++ b/fgvclib/utils/lr_schedules/cosine_decay_schedule.py @@ -1,21 +1,39 @@ import math import numpy as np +from . import lr_schedule +from .lr_schedule import LRSchedule -def cosine_decay(max_epochs, warmup_batchs, max_lr, batchs: int, decay_type: int = 1): - total_batchs = max_epochs * batchs - iters = np.arange(total_batchs - warmup_batchs) +class WarmupCosineDecaySchedule(LRSchedule): + + def __init__(self, optimizer, max_epochs, warmup_steps, max_lr, batch_num_per_epoch, decay_type) -> None: + super().__init__(optimizer) - if decay_type == 1: - schedule = np.array([1e-12 + 0.5 * (max_lr - 1e-12) * (1 + math.cos(math.pi * t / total_batchs)) for t in iters]) - elif decay_type == 2: - schedule = max_lr * np.array([math.cos(7 * math.pi * t / (16 * total_batchs)) for t in iters]) - else: - raise ValueError("Not support this decay type") + total_batchs = max_epochs * batch_num_per_epoch + iters = np.arange(total_batchs - warmup_steps) + self.last_iter = -1 + self.update_level = 'batch' - if warmup_batchs > 0: - warmup_lr_schedule = np.linspace(1e-9, max_lr, warmup_batchs) - schedule = np.concatenate((warmup_lr_schedule, schedule)) + if decay_type == 1: + schedule = np.array([1e-12 + 0.5 * (max_lr - 1e-12) * (1 + math.cos(math.pi * t / total_batchs)) for t in iters]) + elif decay_type == 2: + schedule = max_lr * np.array([math.cos(7 * math.pi * t / (16 * total_batchs)) for t in iters]) + else: + raise ValueError("Not support this decay type") - return schedule - \ No newline at end of file + if warmup_steps > 0: + warmup_lr_schedule = np.linspace(1e-9, max_lr, warmup_steps) + schedule = np.concatenate((warmup_lr_schedule, schedule)) + + self.schedule = schedule + self.step() + + def step(self): + self.last_iter += 1 + for pg in self.optimizer.param_groups: + pg['lr'] = self.schedule[self.last_iter] + + +@lr_schedule("warmup_cosine_decay_schedule") +def warmup_cosine_decay_schedule(optimizer, batch_num_per_epoch, cfg:dict): + return 
diff --git a/fgvclib/utils/lr_schedules/warmup_cosine_schedule.py b/fgvclib/utils/lr_schedules/warmup_cosine_schedule.py
index ddb8854..71fd187 100644
--- a/fgvclib/utils/lr_schedules/warmup_cosine_schedule.py
+++ b/fgvclib/utils/lr_schedules/warmup_cosine_schedule.py
@@ -23,6 +23,6 @@ def step(self, **kwargs):
 
 @lr_schedule("warmup_linear_schedule")
-def warmup_linear_schedule(optimizer, cfg:dict):
+def warmup_linear_schedule(optimizer, batch_num_per_epoch, cfg:dict):
     return WarmupLinearSchedule(optimizer=optimizer, **cfg)
 
diff --git a/fgvclib/utils/update_function/__init__.py b/fgvclib/utils/update_function/__init__.py
index 374fe27..b5c59c6 100644
--- a/fgvclib/utils/update_function/__init__.py
+++ b/fgvclib/utils/update_function/__init__.py
@@ -1,19 +1,17 @@
 import os
 import importlib
 
-from fgvclib.utils.lr_schedules.lr_schedule import LRSchedule
-
 __UPDATE_FN_DICT__ = {}
 
 
-def get_update_function(name) -> LRSchedule:
-    r"""Return the dataset with the given name.
+def get_update_function(name):
+    r"""Return the update model function with the given name.
     Args:
         name (str):
-            The name of dataset.
+            The name of the update function.
     Return:
-        (FGVCDataset): The dataset contructor method.
+        (function): The update function.
 
     """
     return __UPDATE_FN_DICT__[name]
diff --git a/fgvclib/utils/update_function/update_swin_transformer.py b/fgvclib/utils/update_function/update_swin_transformer.py
new file mode 100644
index 0000000..31a1be8
--- /dev/null
+++ b/fgvclib/utils/update_function/update_swin_transformer.py
@@ -0,0 +1,69 @@
+from typing import Iterable
+from torch.autograd import Variable
+from torch.optim import Optimizer
+from torch.cuda.amp import autocast, GradScaler
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+
+from . import update_function
+from fgvclib.utils.logger import Logger
+from fgvclib.criterions import compute_loss_value, detach_loss_value
+
+
+
+@update_function("update_swin_transformer")
+def update_swin_transformer(model: nn.Module, optimizer: Optimizer, pbar:Iterable, lr_schedule=None,
+    strategy:str="update_swin_transformer", use_cuda:bool=True, logger:Logger=None,
+    epoch:int=None, total_epoch:int=None, amp:bool=False, use_selection=False, cfg=None, **kwargs,
+):
+    scaler = GradScaler()
+
+    optimizer.zero_grad()
+
+    for batch_id, (inputs, targets) in enumerate(pbar):
+        model.train()
+        if use_cuda:
+            inputs, targets = inputs.cuda(), targets.cuda()
+        inputs, targets = Variable(inputs), Variable(targets)
+
+        with autocast():
+            """
+            [Model Return]
+            FPN + Selector + Combiner --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depending on your setting)
+                'preds_0', 'preds_1', 'comb_outs'
+            FPN + Selector --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depending on your setting)
+                'preds_0', 'preds_1'
+            FPN --> return 'layer1', 'layer2', 'layer3', 'layer4' (depending on your setting)
+            ~ --> return 'ori_out'
+
+            [Return Tensor]
+            'preds_0': logits that have not been selected by the Selector.
+            'preds_1': logits that have been selected by the Selector.
+            'comb_outs': The prediction of the combiner.
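+
+            [Loss terms] (summary inferred from SwinTransformer.forward in
+            fgvclib/models/sotas/swin_transformer.py)
+            'select_*' -> loss_s (weight lambda_s)    'drop_*'    -> loss_n (weight lambda_n)
+            'layer*'   -> loss_b (weight lambda_b)    'comb_outs' -> loss_c (weight lambda_c)
+            'ori_out'  -> plain cross-entropy (weight 1.0)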
+ """ + out, losses = model(inputs, targets) + total_loss = compute_loss_value(losses) + total_loss /= model.update_freq + + if amp: + scaler.scale(total_loss).backward() + else: + total_loss.backward() + + losses_info = detach_loss_value(losses) + losses_info.update({"iter_loss": total_loss.item()}) + pbar.set_postfix(losses_info) + if (batch_id + 1) % model.update_freq == 0: + if amp: + scaler.step(optimizer) + scaler.update() # next batch + else: + optimizer.step() + optimizer.zero_grad() + + if lr_schedule.update_level == "batch": + lr_schedule.step() + + if lr_schedule.update_level == "epoch": + lr_schedule.step() \ No newline at end of file diff --git a/fgvclib/utils/update_function/update_swint.py b/fgvclib/utils/update_function/update_swint.py deleted file mode 100644 index f69109c..0000000 --- a/fgvclib/utils/update_function/update_swint.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import Iterable -from torch.autograd import Variable -from torch.optim import Optimizer -from torch.cuda.amp import autocast, GradScaler -import torch.nn as nn -import torch -import torch.functional as F - -from . import update_function -from fgvclib.utils.logger import Logger - - - -@update_function("update_swint") -def update_swint(model: nn.Module, optimizer: Optimizer, pbar:Iterable, lr_schedule=None, - strategy:str="update_swint", use_cuda:bool=True, logger:Logger=None, - epoch:int=None, total_epoch:int=None, amp:bool=False, use_selection=False, cfg=None, **kwargs, -): - scaler = GradScaler() - lambda_s = cfg.lambda_s - lambda_n = cfg.lambda_n - num_classes = cfg.MODEL.CLASS_NUM - use_fpn = cfg.MODEL.ARGS["use_fpn"] - lambda_b = cfg.lambda_b - use_combiner = cfg.use_combiner - lambda_c = cfg.lambda_c - update_freq = cfg.update_freq - - optimizer.zero_grad() - total_batchs = len(pbar) # just for log - show_progress = [x / 10 for x in range(11)] # just for log - progress_i = 0 - for batch_id, (datas, labels) in enumerate(pbar): - model.train() - """ = = = = adjust learning rate = = = = """ - iterations = epoch * len(pbar) + batch_id - - for param_group in optimizer.param_groups: - param_group["lr"] = lr_schedule[iterations] - - batch_size = labels.size(0) - - """ = = = = forward and calculate loss = = = = """ - datas, labels = datas.cuda(), labels.cuda() - datas, labels = Variable(datas), Variable(labels) - with autocast(): - """ - [Model Return] - FPN + Selector + Combiner --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depend on your setting) - 'preds_0', 'preds_1', 'comb_outs' - FPN + Selector --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depend on your setting) - 'preds_0', 'preds_1' - FPN --> return 'layer1', 'layer2', 'layer3', 'layer4' (depend on your setting) - ~ --> return 'ori_out' - - [Retuen Tensor] - 'preds_0': logit has not been selected by Selector. - 'preds_1': logit has been selected by Selector. - 'comb_outs': The prediction of combiner. - """ - outs = model(datas) - - loss = 0. 
- for name in outs: - - if "select_" in name: - if not use_selection: - raise ValueError("Selector not use here.") - if lambda_s != 0: - S = outs[name].size(1) - logit = outs[name].view(-1, num_classes).contiguous() - loss_s = nn.CrossEntropyLoss()(logit, - labels.unsqueeze(1).repeat(1, S).flatten(0)) - loss += lambda_s * loss_s - else: - loss_s = 0.0 - - elif "drop_" in name: - if not use_selection: - raise ValueError("Selector not use here.") - - if lambda_n != 0: - S = outs[name].size(1) - logit = outs[name].view(-1, num_classes).contiguous() - n_preds = nn.Tanh()(logit) - labels_0 = torch.zeros([batch_size * S, num_classes]) - 1 - labels_0 = labels_0.cuda() - loss_n = nn.MSELoss()(n_preds, labels_0) - loss += lambda_n * loss_n - else: - loss_n = 0.0 - - elif "layer" in name: - if not use_fpn: - raise ValueError("FPN not use here.") - if lambda_b != 0: - ### here using 'layer1'~'layer4' is default setting, you can change to your own - loss_b = nn.CrossEntropyLoss()(outs[name].mean(1), labels) - loss += lambda_b * loss_b - else: - loss_b = 0.0 - - elif "comb_outs" in name: - if not use_combiner: - raise ValueError("Combiner not use here.") - - if lambda_c != 0: - loss_c = nn.CrossEntropyLoss()(outs[name], labels) - loss += lambda_c * loss_c - - elif "ori_out" in name: - loss_ori = F.cross_entropy(outs[name], labels) - loss += loss_ori - - loss /= update_freq - - """ = = = = calculate gradient = = = = """ - if amp: - scaler.scale(loss).backward() - else: - loss.backward() - - """ = = = = update model = = = = """ - if (batch_id + 1) % update_freq == 0: - if amp: - scaler.step(optimizer) - scaler.update() # next batch - else: - optimizer.step() - optimizer.zero_grad() - \ No newline at end of file diff --git a/fgvclib/utils/update_strategy/__init__.py b/fgvclib/utils/update_strategy/__init__.py index 02f2c26..8d2ce40 100644 --- a/fgvclib/utils/update_strategy/__init__.py +++ b/fgvclib/utils/update_strategy/__init__.py @@ -1,19 +1,17 @@ import os import importlib -from fgvclib.utils.lr_schedules.lr_schedule import LRSchedule -from torch.optim import SGD __UPDATE_ST_DICT__ = {} -def get_update_strategy(name) -> LRSchedule: - r"""Return the dataset with the given name. +def get_update_strategy(name): + r"""Return the update strategy with the given name. Args: name (str): - The name of dataset. + The name of update strategy. Return: - (FGVCDataset): The dataset contructor method. + (function): The update strategy function. 
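+
+    Example (illustrative; "general_strategy" is the strategy name referenced
+    by configs/cal/cal_resnet101.yml)::
+
+        strategy_fn = get_update_strategy("general_strategy")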
""" return __UPDATE_ST_DICT__[name] diff --git a/main.py b/main.py index 7a03241..aa87ca0 100644 --- a/main.py +++ b/main.py @@ -8,7 +8,6 @@ from fgvclib.apis import * from fgvclib.configs import FGVCConfig from fgvclib.utils import init_distributed_mode -from fgvclib.utils.update_function.general_update import general_update def train(cfg: CfgNode): @@ -81,10 +80,12 @@ def train(cfg: CfgNode): metrics = build_metrics(cfg.METRICS) - lr_schedule = build_lr_schedule(optimizer, cfg.LR_SCHEDULE) + lr_schedule = build_lr_schedule(optimizer, cfg.LR_SCHEDULE, train_loader) update_fn = build_update_function(cfg) + evaluate_fn = build_evaluate_function(cfg) + for epoch in range(cfg.START_EPOCH, cfg.EPOCH_NUM): if args.distributed: train_sampler.set_epoch(epoch) @@ -104,7 +105,7 @@ def train(cfg: CfgNode): logger(f'Epoch: {epoch + 1} / {cfg.EPOCH_NUM} Testing ') - acc = evaluate_model(model, test_bar, metrics=metrics, use_cuda=cfg.USE_CUDA) + acc = evaluate_fn(model, test_bar, metrics=metrics, use_cuda=cfg.USE_CUDA) print(acc) logger("Evalution Result:") logger(acc) @@ -138,7 +139,8 @@ def predict(cfg: CfgNode): pbar = tqdm(loader) metrics = build_metrics(cfg.METRICS) - acc = evaluate_model(model, pbar, metrics=metrics, use_cuda=cfg.USE_CUDA) + evaluate_fn = build_evaluate_function(cfg) + acc = evaluate_fn(model, pbar, metrics=metrics, use_cuda=cfg.USE_CUDA) print(acc) From 1736684a4f16912dc9f5fe11c538648fbaf34122 Mon Sep 17 00:00:00 2001 From: xin-ran-w Date: Sun, 5 Feb 2023 23:17:27 +0800 Subject: [PATCH 06/10] fix bugs --- configs/api-net/api-net-resnet101.yml | 155 ------------------ configs/cal/cal_resnet101.yml | 9 +- configs/mutual_channel_loss/mcl_vgg16.yml | 4 +- .../pmg_resnet50.yml | 4 +- .../pmg_v2_resnet50.yml | 4 +- configs/resnet/resnet50.yml | 4 +- configs/resnet/resnet50_cutmix.yml | 4 +- configs/swin_transformer/swin_transformer.yml | 24 +-- fgvclib/apis/build.py | 5 +- fgvclib/configs/config.py | 4 +- fgvclib/metrics/metrics.py | 2 +- fgvclib/models/sotas/cal.py | 35 +++- fgvclib/models/sotas/swin_transformer.py | 10 +- .../utils/evaluate_function/evaluate_model.py | 38 +++++ .../utils/lr_schedules/adjusting_schedule.py | 4 +- main.py | 2 - 16 files changed, 117 insertions(+), 191 deletions(-) delete mode 100644 configs/api-net/api-net-resnet101.yml diff --git a/configs/api-net/api-net-resnet101.yml b/configs/api-net/api-net-resnet101.yml deleted file mode 100644 index c41eb20..0000000 --- a/configs/api-net/api-net-resnet101.yml +++ /dev/null @@ -1,155 +0,0 @@ -EXP_NAME: "APINet-Resnet101" - -RESUME_WEIGHT: ~ - -WEIGHT: - NAME: "APINet_resnet101_new.pth" - SAVE_DIR: "/mnt/sdb/data/wangxinran/weight/fgvclib" - -LOGGER: - NAME: "txt_logger" - -DATASET: - NAME: "CUB_200_2011" - ROOT: "/mnt/sdb/data/wangxinran/dataset" - TRAIN: - BATCH_SIZE: 2 - POSITIVE: 0 - PIN_MEMORY: True - SHUFFLE: True - NUM_WORKERS: 10 - TEST: - BATCH_SIZE: 100 - POSITIVE: 0 - PIN_MEMORY: True - SHUFFLE: False - NUM_WORKERS: 10 - -SAMPLER: - TRAIN: - NAME: "BalancedBatchSampler" - IS_BATCH_SAMPLER: True - ARGS: - - n_samples: 2 - - n_classes: 10 - TEST: - NAME: "SequentialSampler" - IS_BATCH_SAMPLER: False - ARGS: ~ - -MODEL: - NAME: "APINet" - CLASS_NUM: 200 - CRITERIONS: - - name: "score_rank_regular_loss" - args: [] - w: 1.0 - - name: "cross_entropy_loss" - args: [] - w: 1.0 - BACKBONE: - NAME: "resnet101" - ARGS: - - pretrained: True - - del_keys: [] - ENCODER: - NAME: "avg_pooling_2d" - ARGS: - - kernel_size: 14 - - stride: 1 - NECKS: - NAME: "pairwise_interaction" - ARGS: - - in_dim: 4096 - - hid_dim: 512 
From 1736684a4f16912dc9f5fe11c538648fbaf34122 Mon Sep 17 00:00:00 2001
From: xin-ran-w
Date: Sun, 5 Feb 2023 23:17:27 +0800
Subject: [PATCH 06/10] fix bugs

---
 configs/api-net/api-net-resnet101.yml         | 155 ------------------
 configs/cal/cal_resnet101.yml                 |   9 +-
 configs/mutual_channel_loss/mcl_vgg16.yml     |   4 +-
 .../pmg_resnet50.yml                          |   4 +-
 .../pmg_v2_resnet50.yml                       |   4 +-
 configs/resnet/resnet50.yml                   |   4 +-
 configs/resnet/resnet50_cutmix.yml            |   4 +-
 configs/swin_transformer/swin_transformer.yml |  24 +--
 fgvclib/apis/build.py                         |   5 +-
 fgvclib/configs/config.py                     |   4 +-
 fgvclib/metrics/metrics.py                    |   2 +-
 fgvclib/models/sotas/cal.py                   |  35 +++-
 fgvclib/models/sotas/swin_transformer.py      |  10 +-
 .../utils/evaluate_function/evaluate_model.py |  38 +++++
 .../utils/lr_schedules/adjusting_schedule.py  |   4 +-
 main.py                                       |   2 -
 16 files changed, 117 insertions(+), 191 deletions(-)
 delete mode 100644 configs/api-net/api-net-resnet101.yml

diff --git a/configs/api-net/api-net-resnet101.yml b/configs/api-net/api-net-resnet101.yml
deleted file mode 100644
index c41eb20..0000000
--- a/configs/api-net/api-net-resnet101.yml
+++ /dev/null
@@ -1,155 +0,0 @@
-EXP_NAME: "APINet-Resnet101"
-
-RESUME_WEIGHT: ~
-
-WEIGHT:
-  NAME: "APINet_resnet101_new.pth"
-  SAVE_DIR: "/mnt/sdb/data/wangxinran/weight/fgvclib"
-
-LOGGER:
-  NAME: "txt_logger"
-
-DATASET:
-  NAME: "CUB_200_2011"
-  ROOT: "/mnt/sdb/data/wangxinran/dataset"
-  TRAIN:
-    BATCH_SIZE: 2
-    POSITIVE: 0
-    PIN_MEMORY: True
-    SHUFFLE: True
-    NUM_WORKERS: 10
-  TEST:
-    BATCH_SIZE: 100
-    POSITIVE: 0
-    PIN_MEMORY: True
-    SHUFFLE: False
-    NUM_WORKERS: 10
-
-SAMPLER:
-  TRAIN:
-    NAME: "BalancedBatchSampler"
-    IS_BATCH_SAMPLER: True
-    ARGS:
-      - n_samples: 2
-      - n_classes: 10
-  TEST:
-    NAME: "SequentialSampler"
-    IS_BATCH_SAMPLER: False
-    ARGS: ~
-
-MODEL:
-  NAME: "APINet"
-  CLASS_NUM: 200
-  CRITERIONS:
-    - name: "score_rank_regular_loss"
-      args: []
-      w: 1.0
-    - name: "cross_entropy_loss"
-      args: []
-      w: 1.0
-  BACKBONE:
-    NAME: "resnet101"
-    ARGS:
-      - pretrained: True
-      - del_keys: []
-  ENCODER:
-    NAME: "avg_pooling_2d"
-    ARGS:
-      - kernel_size: 14
-      - stride: 1
-  NECKS:
-    NAME: "pairwise_interaction"
-    ARGS:
-      - in_dim: 4096
-      - hid_dim: 512
-      - out_dim: 2048
-  HEADS:
-    NAME: "classifier_drop_1fc"
-    ARGS:
-      - in_dim:
-        - 2048
-
-TRANSFORMS:
-  TRAIN:
-    - name: "resize"
-      size:
-        - 512
-        - 512
-    - name: "random_crop"
-      size:
-        - 448
-        - 448
-      padding: ~
-    - name: "random_horizontal_flip"
-      prob: 0.5
-    - name: "to_tensor"
-    - name: "normalize"
-      mean:
-        - 0.485
-        - 0.456
-        - 0.406
-      std:
-        - 0.229
-        - 0.224
-        - 0.225
-  TEST:
-    - name: "resize"
-      size:
-        - 512
-        - 512
-    - name: "center_crop"
-      size:
-        - 448
-        - 448
-    - name: "to_tensor"
-    - name: "normalize"
-      mean:
-        - 0.485
-        - 0.456
-        - 0.406
-      std:
-        - 0.229
-        - 0.224
-        - 0.225
-
-OPTIMIZER:
-  NAME: "SGD"
-  MOMENTUM: 0.9
-  LR:
-    backbone: 0.002
-    encoder: 0.01
-    necks: 0.01
-    heads: 0.01
-
-ITERATION_NUM: ~
-EPOCH_NUM: 100
-START_EPOCH: 0
-UPDATE_STRATEGY: "general_updating"
-
-
-# Validation details
-PER_ITERATION: ~
-PER_EPOCH: ~
-METRICS:
-  - name: "accuracy(topk=1)"
-    metric: "accuracy"
-    top_k: 1
-    threshold: ~
-  - name: "accuracy(topk=5)"
-    metric: "accuracy"
-    top_k: 5
-    threshold: ~
-  - name: "recall(threshold=0.5)"
-    metric: "recall"
-    top_k: 1
-    threshold: 0.5
-  - name: "precision(threshold=0.5)"
-    metric: "precision"
-    top_k: 1
-    threshold: 0.5
-
-INTERPRETER:
-  NAME: "cam"
-  METHOD: "gradcam"
-  TARGET_LAYERS:
-    - "backbone.layer4"
diff --git a/configs/cal/cal_resnet101.yml b/configs/cal/cal_resnet101.yml
index 05fce78..f954df0 100644
--- a/configs/cal/cal_resnet101.yml
+++ b/configs/cal/cal_resnet101.yml
@@ -105,8 +105,9 @@ TRANSFORMS:
 
 OPTIMIZER:
   NAME: SGD
-  MOMENTUM: 0.9
-  WEIGHT_DECAY: 1e-5
+  ARGS:
+    - momentum: 0.9
+    - weight_decay: 0.0005
   LR:
     backbone: 0.001
     encoder: 0.001
@@ -116,14 +117,14 @@ OPTIMIZER:
 ITERATION_NUM: ~
 EPOCH_NUM: 160
 START_EPOCH: 0
-UPDATE_STRATEGY: "general_updating"
+UPDATE_STRATEGY: "general_strategy"
 
 LR_SCHEDULE:
   NAME: "adjusting_schedule"
   ARGS:
     - base_lr: 0.001
     - base_duration: 2.0
     - base_rate: 0.9
-    - update_level: "batch_update"
+    # - update_level: "batch_update"
 
 # Validation details
 PER_ITERATION: ~
diff --git a/configs/mutual_channel_loss/mcl_vgg16.yml b/configs/mutual_channel_loss/mcl_vgg16.yml
index 6d776e4..e4ee74d 100644
--- a/configs/mutual_channel_loss/mcl_vgg16.yml
+++ b/configs/mutual_channel_loss/mcl_vgg16.yml
@@ -102,7 +102,9 @@ TRANSFORMS:
 
 OPTIMIZER:
   NAME: SGD
-  MOMENTUM: 0.9
+  ARGS:
+    - momentum: 0.9
+    - weight_decay: 0.0005
   LR:
     backbone: 0.001
     encoder: 0.01
diff --git a/configs/progressive_multi_granularity_learning/pmg_resnet50.yml b/configs/progressive_multi_granularity_learning/pmg_resnet50.yml
index f49bada..85e58b1 100644
--- a/configs/progressive_multi_granularity_learning/pmg_resnet50.yml
+++ b/configs/progressive_multi_granularity_learning/pmg_resnet50.yml
@@ -106,7 +106,9 @@ TRANSFORMS:
 
 OPTIMIZER:
   NAME: "SGD"
-  MOMENTUM: 0.9
+  ARGS:
+    - momentum: 0.9
+    - weight_decay: 0.0005
   LR:
     backbone: 0.0002
     encoder: 0.002
diff --git a/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml b/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml
index 7bf908f..1c1b3f7 100644
--- a/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml
+++ b/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml
@@ -118,7 +118,9 @@ TRANSFORMS:
 
 OPTIMIZER:
   NAME: "SGD"
-  MOMENTUM: 0.9
+  ARGS:
+    - momentum: 0.9
+    - weight_decay: 0.0005
   LR:
     backbone: 0.0005
     encoder: ~
diff --git a/configs/resnet/resnet50.yml b/configs/resnet/resnet50.yml
index 485d607..3c600e3 100644
--- a/configs/resnet/resnet50.yml
+++ b/configs/resnet/resnet50.yml
@@ -88,7 +88,9 @@ TRANSFORMS:
 
 OPTIMIZER:
   NAME: "SGD"
-  MOMENTUM: 0.9
+  ARGS:
+    - momentum: 0.9
+    - weight_decay: 0.0005
   LR:
     backbone: 0.0002
     encoder: 0.002
diff --git a/configs/resnet/resnet50_cutmix.yml b/configs/resnet/resnet50_cutmix.yml
index 86f517f..9e12b2e 100644
--- a/configs/resnet/resnet50_cutmix.yml
+++ b/configs/resnet/resnet50_cutmix.yml
@@ -88,7 +88,9 @@ TRANSFORMS:
 
 OPTIMIZER:
   NAME: "SGD"
-  MOMENTUM: 0.9
+  ARGS:
+    - momentum: 0.9
+    - weight_decay: 0.0005
   LR:
     backbone: 0.0002
     encoder: 0.002
diff --git a/configs/swin_transformer/swin_transformer.yml b/configs/swin_transformer/swin_transformer.yml
index fbb7f00..f648a56 100644
--- a/configs/swin_transformer/swin_transformer.yml
+++ b/configs/swin_transformer/swin_transformer.yml
@@ -13,17 +13,17 @@ DATASET:
   NAME: "CUB_200_2011"
   ROOT: "/mnt/sdb/data/wangxinran/dataset"
   TRAIN:
-    BATCH_SIZE: 8
+    BATCH_SIZE: 16
     POSITIVE: 0
     PIN_MEMORY: True
     SHUFFLE: True
-    NUM_WORKERS: 2
+    NUM_WORKERS: 4
   TEST:
-    BATCH_SIZE: 8
+    BATCH_SIZE: 16
     POSITIVE: 0
     PIN_MEMORY: False
     SHUFFLE: True
-    NUM_WORKERS: 1
+    NUM_WORKERS: 4
 
 MODEL:
   NAME: "SwinTransformer"
@@ -117,20 +117,20 @@ TRANSFORMS:
     - 0.225
 
 OPTIMIZER:
-  NAME: "SGD"
-  MOMENTUM: 0.9
+  NAME: "AdamW"
+  ARGS:
+    - weight_decay: 0.0005
   LR:
-    backbone: 0.005
-    encoder: 0.005
-    necks: 0.005
-    heads: 0.005
-  WEIGHT_DECAY: 0.0005
+    backbone: 0.0001
+    encoder: 0.0001
+    necks: 0.0001
+    heads: 0.0001
 
 LR_SCHEDULE:
   NAME: "warmup_cosine_decay_schedule"
   ARGS:
     - warmup_steps: 800
-    - max_lr: 0.005
+    - max_lr: 0.0001
     - max_epochs: 50
     - decay_type: 1
diff --git a/fgvclib/apis/build.py b/fgvclib/apis/build.py
index 9b192d3..15ecf68 100644
--- a/fgvclib/apis/build.py
+++ b/fgvclib/apis/build.py
@@ -26,6 +26,7 @@ from fgvclib.models.encoders import get_encoder
 from fgvclib.models.necks import get_neck
 from fgvclib.models.heads import get_head
+from fgvclib.optimizers import get_optimizer
 from fgvclib.transforms import get_transform
 from fgvclib.utils.logger import get_logger, Logger
 from fgvclib.utils.interpreter import get_interpreter, Interpreter
@@ -163,8 +164,8 @@ def build_optimizer(optim_cfg: CfgNode, model:t.Union[nn.Module, nn.DataParallel
                 'params': getattr(model, attr).parameters(),
                 'lr': optim_cfg.LR[attr]
             })
-
-    optimizer = optim.SGD(params=params, momentum=optim_cfg.MOMENTUM, weight_decay=optim_cfg.WEIGHT_DECAY)
+
+    optimizer = get_optimizer(optim_cfg.NAME)(params, tltd(optim_cfg.ARGS))
 
     return optimizer
diff --git a/fgvclib/configs/config.py b/fgvclib/configs/config.py
index c79b5aa..36383d4 100644
--- a/fgvclib/configs/config.py
+++ b/fgvclib/configs/config.py
@@ -112,8 +112,8 @@ def __init__(self):
         # Optimizer
         self.cfg.OPTIMIZER = CN()
         self.cfg.OPTIMIZER.NAME = "SGD"
-        self.cfg.OPTIMIZER.MOMENTUM = 0.9
-        self.cfg.OPTIMIZER.WEIGHT_DECAY = 5e-4
+        self.cfg.OPTIMIZER.ARGS = [{"momentum": 0.9}, {"weight_decay": 5e-4}]
+
         self.cfg.OPTIMIZER.LR = CN()
         self.cfg.OPTIMIZER.LR.backbone = None
         self.cfg.OPTIMIZER.LR.encoder = None
diff --git a/fgvclib/metrics/metrics.py b/fgvclib/metrics/metrics.py
index 00845ed..8543302 100644
--- a/fgvclib/metrics/metrics.py
+++ b/fgvclib/metrics/metrics.py
@@ -84,7 +84,7 @@ def recall(name:str="recall(threshold=0.5)", top_k:int=None, threshold:float=0.5
         top_k (int):
             Number of the highest probability or logit score predictions considered finding the correct label.
-        threshhold (float, optional):
+        threshold (float, optional):
             Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case of binary or multi-label inputs.
diff --git a/fgvclib/models/sotas/cal.py b/fgvclib/models/sotas/cal.py
index cf00c63..1ccc497 100644
--- a/fgvclib/models/sotas/cal.py
+++ b/fgvclib/models/sotas/cal.py
@@ -72,7 +72,7 @@ def __init__(self, cfg: CfgNode, backbone: nn.Module, encoder: nn.Module, necks:
         super().__init__(cfg, backbone, encoder, necks, heads, criterions)
         self.out_channels = 2048
-        self.num_classes = cfg.NUM_CLASS
+        self.num_classes = cfg.CLASS_NUM
         self.M = 32
         self.net = 'resnet101'
         self.register_buffer('feature_center', torch.zeros(self.num_classes, self.M * self.out_channels))  # 32 * 2048
@@ -170,3 +170,36 @@ def load_state_dict(self, state_dict, strict=True):
         model_dict.update(pretrained_dict)
         super(WSDAN_CAL, self).load_state_dict(model_dict)
+
+    def infer_aux(self, x):
+        x_m = torch.flip(x, [3])
+
+        y_pred_raw, y_pred_aux_raw, _, attention_map = self.infer(x)
+        y_pred_raw_m, y_pred_aux_raw_m, _, attention_map_m = self.infer(x_m)
+
+        crop_images = batch_augment(x, attention_map, mode='crop', theta=0.3, padding_ratio=0.1)
+        y_pred_crop, y_pred_aux_crop, _, _ = self.infer(crop_images)
+
+        crop_images2 = batch_augment(x, attention_map, mode='crop', theta=0.2, padding_ratio=0.1)
+        y_pred_crop2, y_pred_aux_crop2, _, _ = self.infer(crop_images2)
+
+        crop_images3 = batch_augment(x, attention_map, mode='crop', theta=0.1, padding_ratio=0.05)
+        y_pred_crop3, y_pred_aux_crop3, _, _ = self.infer(crop_images3)
+
+        crop_images_m = batch_augment(x_m, attention_map_m, mode='crop', theta=0.3, padding_ratio=0.1)
+        y_pred_crop_m, y_pred_aux_crop_m, _, _ = self.infer(crop_images_m)
+
+        crop_images_m2 = batch_augment(x_m, attention_map_m, mode='crop', theta=0.2, padding_ratio=0.1)
+        y_pred_crop_m2, y_pred_aux_crop_m2, _, _ = self.infer(crop_images_m2)
+
+        crop_images_m3 = batch_augment(x_m, attention_map_m, mode='crop', theta=0.1, padding_ratio=0.05)
+        y_pred_crop_m3, y_pred_aux_crop_m3, _, _ = self.infer(crop_images_m3)
+
+        y_pred = (y_pred_raw + y_pred_crop + y_pred_crop2 + y_pred_crop3) / 4.
+        y_pred_m = (y_pred_raw_m + y_pred_crop_m + y_pred_crop_m2 + y_pred_crop_m3) / 4.
+        y_pred = (y_pred + y_pred_m) / 2.
+
+        y_pred_aux = (y_pred_aux_raw + y_pred_aux_crop + y_pred_aux_crop2 + y_pred_aux_crop3) / 4.
+        y_pred_aux_m = (y_pred_aux_raw_m + y_pred_aux_crop_m + y_pred_aux_crop_m2 + y_pred_aux_crop_m3) / 4.
+        y_pred_aux = (y_pred_aux + y_pred_aux_m) / 2.
+
+        return y_pred_aux
diff --git a/fgvclib/models/sotas/swin_transformer.py b/fgvclib/models/sotas/swin_transformer.py
index 53feaa3..38656f4 100644
--- a/fgvclib/models/sotas/swin_transformer.py
+++ b/fgvclib/models/sotas/swin_transformer.py
@@ -148,7 +148,7 @@ def forward(self, x, target=None):
                 if self.lambda_s != 0:
                     S = logits[name].size(1)
                     logit = logits[name].view(-1, self.num_classes).contiguous()
-                    loss_s = nn.CrossEntropyLoss()(logit, target.unsqueeze(1).repeat(1, S).flatten(0))
+                    loss_s = nn.CrossEntropyLoss()(logit.float(), target.unsqueeze(1).repeat(1, S).flatten(0))
                     losses.append(LossItem(name="loss_s", value=loss_s, weight=self.lambda_s))
 
             elif "drop_" in name:
@@ -160,7 +160,7 @@ def forward(self, x, target=None):
                     logit = logits[name].view(-1, self.num_classes).contiguous()
                     n_preds = nn.Tanh()(logit)
                     labels_0 = (torch.zeros([batch_size * S, self.num_classes]) - 1).to(device)
-                    loss_n = nn.MSELoss()(n_preds, labels_0)
+                    loss_n = nn.MSELoss()(n_preds.float(), labels_0)
                     losses.append(LossItem(name="loss_n", value=loss_n, weight=self.lambda_n))
 
@@ -169,7 +169,7 @@ def forward(self, x, target=None):
                     raise ValueError("FPN not use here.")
                 if self.lambda_b != 0:
                     ### here using 'layer1'~'layer4' is default setting, you can change to your own
-                    loss_b = nn.CrossEntropyLoss()(logits[name].mean(1), target)
+                    loss_b = nn.CrossEntropyLoss()(logits[name].mean(1).float(), target)
                     losses.append(LossItem(name="loss_b", value=loss_b, weight=self.lambda_b))
 
             elif "comb_outs" in name:
@@ -177,11 +177,11 @@ def forward(self, x, target=None):
                     raise ValueError("Combiner not use here.")
 
                 if self.lambda_c != 0:
-                    loss_c = nn.CrossEntropyLoss()(logits[name], target)
+                    loss_c = nn.CrossEntropyLoss()(logits[name].float(), target)
                     losses.append(LossItem(name="loss_c", value=loss_c, weight=self.lambda_c))
 
             elif "ori_out" in name:
-                loss_ori = F.cross_entropy(logits[name], target)
+                loss_ori = F.cross_entropy(logits[name].float(), target)
                 losses.append(LossItem(name="loss_ori", value=loss_ori, weight=1.0))
 
         return logits, losses
diff --git a/fgvclib/utils/evaluate_function/evaluate_model.py b/fgvclib/utils/evaluate_function/evaluate_model.py
index 8bb1a78..9dc6fb5 100644
--- a/fgvclib/utils/evaluate_function/evaluate_model.py
+++ b/fgvclib/utils/evaluate_function/evaluate_model.py
@@ -88,3 +88,41 @@ def swin_transformer_evaluate(model:nn.Module, p_bar:t.Iterable, metrics:t.List[
         })
 
     return results
+
+@evaluate_function("wsdan_cal_evaluate")
+def wsdan_cal_evaluate(model:nn.Module, p_bar:t.Iterable, metrics:t.List[NamedMetric], use_cuda:bool=True) -> t.Dict:
+    r"""Evaluate the FGVC model.
+
+    Args:
+        model (nn.Module):
+            The FGVC model.
+        p_bar (iterable):
+            An iterator that provides the test data.
+        metrics (List[NamedMetric]):
+            List of metrics.
+        use_cuda (boolean, optional):
+            Whether to use gpu.
+
+    Returns:
+        dict: The result dict.
+    """
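PATCH 06 also repairs `adjusting_schedule.step`, which had been reading `optimizer` from an argument instead of the instance. The decay rule itself is unchanged; written as a standalone function it amounts to the following sketch (the example numbers use cal_resnet101.yml's settings):

    def adjusted_lr(base_lr, base_rate, base_duration, current_epoch, batch_idx, total_batch):
        # Fractional epoch progress makes the decay smooth across batches.
        progress = current_epoch + float(batch_idx) / total_batch
        return base_lr * base_rate ** (progress / base_duration)

    # With base_lr=0.001, base_rate=0.9, base_duration=2.0:
    # halfway through epoch 10 -> 0.001 * 0.9 ** (10.5 / 2.0) ≈ 0.00058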
+ """ + + model.eval() + results = dict() + + with torch.no_grad(): + for _, (inputs, targets) in enumerate(p_bar): + if use_cuda: + inputs, targets = inputs.cuda(), targets.cuda() + inputs, targets = Variable(inputs), Variable(targets) + logits = model(inputs) + for metric in metrics: + _ = metric.update(logits["comb_outs"], targets) + + for metric in metrics: + result = metric.compute() + results.update({ + metric.name: round(result.item(), 3) + }) + + return results diff --git a/fgvclib/utils/lr_schedules/adjusting_schedule.py b/fgvclib/utils/lr_schedules/adjusting_schedule.py index 37ba952..523bce3 100644 --- a/fgvclib/utils/lr_schedules/adjusting_schedule.py +++ b/fgvclib/utils/lr_schedules/adjusting_schedule.py @@ -12,10 +12,10 @@ def __init__(self, optimizer, base_rate, base_duration, base_lr) -> None: self.base_lr = base_lr self.update_level = "batch" - def step(self, optimizer, batch_idx, current_epoch, total_batch, **kwargs): + def step(self, batch_idx, current_epoch, total_batch, **kwargs): iter = float(batch_idx) / total_batch lr = self.base_lr * pow(self.base_rate, (current_epoch + iter) / self.base_duration) - for pg in optimizer.param_groups: + for pg in self.optimizer.param_groups: pg['lr'] = lr diff --git a/main.py b/main.py index aa87ca0..b63c0f2 100644 --- a/main.py +++ b/main.py @@ -58,8 +58,6 @@ def train(cfg: CfgNode): train_sampler = build_sampler(sampler_cfg.TRAIN)(train_set, **tltd(sampler_cfg.TRAIN.ARGS)) test_sampler = build_sampler(sampler_cfg.TEST)(test_set, **tltd(sampler_cfg.TEST.ARGS)) - - train_loader = build_dataloader( dataset=train_set, mode_cfg=cfg.DATASET.TRAIN, From ce36f4521bade5e6793ef2b086d1e553a58eacc7 Mon Sep 17 00:00:00 2001 From: xin-ran-w Date: Sun, 5 Feb 2023 23:17:50 +0800 Subject: [PATCH 07/10] add optimizers --- configs/api_net/api_net_resnet101.yml | 157 ++++++++++++++++++++++++++ fgvclib/models/__init__.py | 0 fgvclib/optimizers/__init__.py | 38 +++++++ fgvclib/optimizers/adam.py | 9 ++ fgvclib/optimizers/adamw.py | 8 ++ fgvclib/optimizers/sgd.py | 8 ++ 6 files changed, 220 insertions(+) create mode 100644 configs/api_net/api_net_resnet101.yml create mode 100644 fgvclib/models/__init__.py create mode 100644 fgvclib/optimizers/__init__.py create mode 100644 fgvclib/optimizers/adam.py create mode 100644 fgvclib/optimizers/adamw.py create mode 100644 fgvclib/optimizers/sgd.py diff --git a/configs/api_net/api_net_resnet101.yml b/configs/api_net/api_net_resnet101.yml new file mode 100644 index 0000000..9911e37 --- /dev/null +++ b/configs/api_net/api_net_resnet101.yml @@ -0,0 +1,157 @@ +EXP_NAME: "APINet-Resnet101" + +RESUME_WEIGHT: ~ + +WEIGHT: + NAME: "APINet_resnet101_new.pth" + SAVE_DIR: "/mnt/sdb/data/wangxinran/weight/fgvclib" + +LOGGER: + NAME: "txt_logger" + +DATASET: + NAME: "CUB_200_2011" + ROOT: "/mnt/sdb/data/wangxinran/dataset" + TRAIN: + BATCH_SIZE: 2 + POSITIVE: 0 + PIN_MEMORY: True + SHUFFLE: True + NUM_WORKERS: 10 + TEST: + BATCH_SIZE: 100 + POSITIVE: 0 + PIN_MEMORY: True + SHUFFLE: False + NUM_WORKERS: 10 + +SAMPLER: + TRAIN: + NAME: "BalancedBatchSampler" + IS_BATCH_SAMPLER: True + ARGS: + - n_samples: 2 + - n_classes: 10 + TEST: + NAME: "SequentialSampler" + IS_BATCH_SAMPLER: False + ARGS: ~ + +MODEL: + NAME: "APINet" + CLASS_NUM: 200 + CRITERIONS: + - name: "score_rank_regular_loss" + args: [] + w: 1.0 + - name: "cross_entropy_loss" + args: [] + w: 1.0 + BACKBONE: + NAME: "resnet101" + ARGS: + - pretrained: True + - del_keys: [] + ENCODER: + NAME: "avg_pooling_2d" + ARGS: + - kernel_size: 14 + - stride: 1 + NECKS: + NAME: 
"pairwise_interaction" + ARGS: + - in_dim: 4096 + - hid_dim: 512 + - out_dim: 2048 + HEADS: + NAME: "classifier_drop_1fc" + ARGS: + - in_dim: + - 2048 + +TRANSFORMS: + TRAIN: + - name: "resize" + size: + - 512 + - 512 + - name: "random_crop" + size: + - 448 + - 448 + padding: ~ + - name: "random_horizontal_flip" + prob: 0.5 + - name: "to_tensor" + - name: "normalize" + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + TEST: + - name: "resize" + size: + - 512 + - 512 + - name: "center_crop" + size: + - 448 + - 448 + - name: "to_tensor" + - name: "normalize" + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + +OPTIMIZER: + NAME: "SGD" + ARGS: + - momentum: 0.9 + - weight_decay: 0.0005 + LR: + backbone: 0.002 + encoder: 0.01 + necks: 0.01 + heads: 0.01 + +ITERATION_NUM: ~ +EPOCH_NUM: 100 +START_EPOCH: 0 +UPDATE_STRATEGY: "general_strategy" + + +# Validation details +PER_ITERATION: ~ +PER_EPOCH: ~ +METRICS: + - name: "accuracy(topk=1)" + metric: "accuracy" + top_k: 1 + threshold: ~ + - name: "accuracy(topk=5)" + metric: "accuracy" + top_k: 5 + threshold: ~ + - name: "recall(threshold=0.5)" + metric: "recall" + top_k: 1 + threshold: 0.5 + - name: "precision(threshold=0.5)" + metric: "precision" + top_k: 1 + threshold: 0.5 + +INTERPRETER: + NAME: "cam" + METHOD: "gradcam" + TARGET_LAYERS: + - "backbone.layer4" diff --git a/fgvclib/models/__init__.py b/fgvclib/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fgvclib/optimizers/__init__.py b/fgvclib/optimizers/__init__.py new file mode 100644 index 0000000..7fa56e2 --- /dev/null +++ b/fgvclib/optimizers/__init__.py @@ -0,0 +1,38 @@ +import os +import importlib + + +__OPTMIZER_DICT__ = {} + + +def get_optimizer(name): + r"""Return the metric with the given name. + + Args: + name (str): + The name of metric. + + Return: + The metric contructor method. + """ + + return __OPTMIZER_DICT__[name] + +def get_optimizer(name): + return __OPTMIZER_DICT__[name] + +def optimizer(name): + + def register_function_fn(cls): + if name in __OPTMIZER_DICT__: + raise ValueError("Name %s already registered!" % name) + __OPTMIZER_DICT__[name] = cls + return cls + + return register_function_fn + +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith('.py') and not file.startswith('_'): + module_name = file[:file.find('.py')] + module = importlib.import_module('fgvclib.optimizers.' 
diff --git a/fgvclib/optimizers/adam.py b/fgvclib/optimizers/adam.py
new file mode 100644
index 0000000..ec75178
--- /dev/null
+++ b/fgvclib/optimizers/adam.py
@@ -0,0 +1,9 @@
+from torch.optim import Adam
+
+
+from . import optimizer
+
+@optimizer("Adam")
+def adam(params, cfg):
+    return Adam(params=params, **cfg)
+
diff --git a/fgvclib/optimizers/adamw.py b/fgvclib/optimizers/adamw.py
new file mode 100644
index 0000000..312b37c
--- /dev/null
+++ b/fgvclib/optimizers/adamw.py
@@ -0,0 +1,8 @@
+from torch.optim import AdamW
+
+
+from . import optimizer
+
+@optimizer("AdamW")
+def adamw(params, cfg):
+    return AdamW(params=params, **cfg)
\ No newline at end of file
diff --git a/fgvclib/optimizers/sgd.py b/fgvclib/optimizers/sgd.py
new file mode 100644
index 0000000..b618f6b
--- /dev/null
+++ b/fgvclib/optimizers/sgd.py
@@ -0,0 +1,8 @@
+from torch.optim import SGD
+
+
+from . import optimizer
+
+@optimizer("SGD")
+def adamw(params, cfg):
+    return SGD(params=params, **cfg)
\ No newline at end of file
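Throughout these configs, optimizer arguments are written as a YAML list of single-key mappings, and build.py passes them through `tltd` before they reach the wrappers above. `tltd` itself is not shown anywhere in this series; presumably it just flattens that list into one kwargs dict, roughly along these lines (a guess at the helper, not its actual source):

    def tltd(args_list):
        # Turn a list of single-key dicts, as parsed from the YAML
        # ARGS blocks, into one flat kwargs dict.
        merged = {}
        for item in (args_list or []):
            merged.update(item)
        return merged

    # [{'momentum': 0.9}, {'weight_decay': 0.0005}]
    # -> {'momentum': 0.9, 'weight_decay': 0.0005}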
From 3f600665fb596f120a7654935e56acd5c5da677e Mon Sep 17 00:00:00 2001
From: xin-ran-w
Date: Mon, 6 Feb 2023 22:29:08 +0800
Subject: [PATCH 08/10] update build_optimizer

---
 configs/api_net/api_net_resnet101.yml         |  1 +
 configs/cal/cal_resnet101.yml                 |  1 +
 configs/mutual_channel_loss/mcl_vgg16.yml     |  1 +
 .../pmg_resnet50.yml                          |  1 +
 .../pmg_v2_resnet50.yml                       |  1 +
 configs/resnet/resnet50.yml                   |  3 +-
 configs/resnet/resnet50_cutmix.yml            |  3 +-
 configs/swin_transformer/swin_transformer.yml |  3 +
 configs/transfg/transFG_ViT_B_16.yml          |  1 +
 fgvclib/apis/build.py                         | 56 ++++++++++++++-----
 fgvclib/configs/config.py                     |  1 +
 fgvclib/models/sotas/swin_transformer.py      |  2 +-
 fgvclib/optimizers/adam.py                    |  4 +-
 fgvclib/optimizers/adamw.py                   |  4 +-
 fgvclib/optimizers/sgd.py                     |  4 +-
 .../lr_schedules/cosine_decay_schedule.py     |  9 ++-
 .../update_swin_transformer.py                | 25 +++++----
 main.py                                       |  6 +-
 18 files changed, 86 insertions(+), 40 deletions(-)

diff --git a/configs/api_net/api_net_resnet101.yml b/configs/api_net/api_net_resnet101.yml
index 9911e37..c0320b4 100644
--- a/configs/api_net/api_net_resnet101.yml
+++ b/configs/api_net/api_net_resnet101.yml
@@ -118,6 +118,7 @@ OPTIMIZER:
     - momentum: 0.9
     - weight_decay: 0.0005
   LR:
+    base: 0.002
     backbone: 0.002
     encoder: 0.01
     necks: 0.01
diff --git a/configs/cal/cal_resnet101.yml b/configs/cal/cal_resnet101.yml
index f954df0..eaa8824 100644
--- a/configs/cal/cal_resnet101.yml
+++ b/configs/cal/cal_resnet101.yml
@@ -109,6 +109,7 @@ OPTIMIZER:
     - momentum: 0.9
     - weight_decay: 0.0005
   LR:
+    base: 0.001
     backbone: 0.001
     encoder: 0.001
     necks: 0.001
diff --git a/configs/mutual_channel_loss/mcl_vgg16.yml b/configs/mutual_channel_loss/mcl_vgg16.yml
index e4ee74d..005a9b8 100644
--- a/configs/mutual_channel_loss/mcl_vgg16.yml
+++ b/configs/mutual_channel_loss/mcl_vgg16.yml
@@ -106,6 +106,7 @@ OPTIMIZER:
     - momentum: 0.9
     - weight_decay: 0.0005
   LR:
+    base: 0.001
     backbone: 0.001
     encoder: 0.01
     necks: 0.01
diff --git a/configs/progressive_multi_granularity_learning/pmg_resnet50.yml b/configs/progressive_multi_granularity_learning/pmg_resnet50.yml
index 85e58b1..976a374 100644
--- a/configs/progressive_multi_granularity_learning/pmg_resnet50.yml
+++ b/configs/progressive_multi_granularity_learning/pmg_resnet50.yml
@@ -110,6 +110,7 @@ OPTIMIZER:
     - momentum: 0.9
     - weight_decay: 0.0005
   LR:
+    base: 0.0002
     backbone: 0.0002
     encoder: 0.002
     necks: 0.002
diff --git a/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml b/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml
index 1c1b3f7..ff26f49 100644
--- a/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml
+++ b/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml
@@ -122,6 +122,7 @@ OPTIMIZER:
     - momentum: 0.9
     - weight_decay: 0.0005
   LR:
+    base: 0.0005
     backbone: 0.0005
     encoder: ~
     necks: 0.005
diff --git a/configs/resnet/resnet50.yml b/configs/resnet/resnet50.yml
index 3c600e3..fdd9976 100644
--- a/configs/resnet/resnet50.yml
+++ b/configs/resnet/resnet50.yml
@@ -91,7 +91,8 @@ OPTIMIZER:
   ARGS:
     - momentum: 0.9
     - weight_decay: 0.0005
-  LR:
+  LR:
+    base: 0.0002
     backbone: 0.0002
     encoder: 0.002
     necks: 0.002
diff --git a/configs/resnet/resnet50_cutmix.yml b/configs/resnet/resnet50_cutmix.yml
index 9e12b2e..0930d5d 100644
--- a/configs/resnet/resnet50_cutmix.yml
+++ b/configs/resnet/resnet50_cutmix.yml
@@ -91,7 +91,8 @@ OPTIMIZER:
   ARGS:
     - momentum: 0.9
     - weight_decay: 0.0005
-  LR:
+  LR:
+    base: 0.0002
     backbone: 0.0002
     encoder: 0.002
     necks: 0.002
diff --git a/configs/swin_transformer/swin_transformer.yml b/configs/swin_transformer/swin_transformer.yml
index f648a56..728aa11 100644
--- a/configs/swin_transformer/swin_transformer.yml
+++ b/configs/swin_transformer/swin_transformer.yml
@@ -2,6 +2,8 @@ EXP_NAME: "SwinT"
 
 RESUME_WEIGHT: ~
 
+DISTRIBUTED: False
+
 WEIGHT:
   NAME: "swinT.pth"
   SAVE_DIR: "/mnt/sdb/data/wangxinran/weight/fgvclib"
@@ -121,6 +123,7 @@ OPTIMIZER:
   ARGS:
     - weight_decay: 0.0005
   LR:
+    base: 0.0001
     backbone: 0.0001
     encoder: 0.0001
     necks: 0.0001
diff --git a/configs/transfg/transFG_ViT_B_16.yml b/configs/transfg/transFG_ViT_B_16.yml
index fd5c668..a946b58 100644
--- a/configs/transfg/transFG_ViT_B_16.yml
+++ b/configs/transfg/transFG_ViT_B_16.yml
@@ -120,6 +120,7 @@ OPTIMIZER:
   MOMENTUM: 0.9
   WEIGHT_DECAY: 0.0
   LR:
+    base: 0.08
     backbone: 0.08
     encoder: 0.08
     necks: 0.08
diff --git a/fgvclib/apis/build.py b/fgvclib/apis/build.py
index 15ecf68..7cbebba 100644
--- a/fgvclib/apis/build.py
+++ b/fgvclib/apis/build.py
@@ -145,28 +145,56 @@ def build_optimizer(optim_cfg: CfgNode, model:t.Union[nn.Module, nn.DataParallel
         Optimizer: A Pytorch Optimizer.
     """
-    params= list()
+    params = list()
     model_attrs = ["backbone", "encoder", "necks", "heads"]
 
-    if isinstance(model, nn.DataParallel) or isinstance(model, nn.parallel.DistributedDataParallel):
-        for attr in model_attrs:
-            if getattr(model.module, attr) and optim_cfg.LR[attr]:
-                params.append({
-                    'params': getattr(model.module, attr).parameters(),
-                    'lr': optim_cfg.LR[attr]
-                })
-                print(attr, optim_cfg.LR[attr])
+    # if isinstance(model, nn.DataParallel) or isinstance(model, nn.parallel.DistributedDataParallel):
+    #     for attr in model_attrs:
+    #         if getattr(model.module, attr) and optim_cfg.LR[attr]:
+    #             params.append({
+    #                 'params': getattr(model.module, attr).parameters(),
+    #                 'lr': optim_cfg.LR[attr]
+    #             })
+
+    # else:
+    #     for attr in model_attrs:
+    #         if getattr(model, attr) and optim_cfg.LR[attr]:
+    #             params.append({
+    #                 'params': getattr(model, attr).parameters(),
+    #                 'lr': optim_cfg.LR[attr]
+    #             })
+
+
+    if isinstance(model, nn.DataParallel) or isinstance(model, nn.parallel.DistributedDataParallel):
+        m = model.module
     else:
-        for attr in model_attrs:
-            if getattr(model, attr) and optim_cfg.LR[attr]:
+        m = model
+
+    for n, p in m.named_parameters():
+        is_other = True
+        if p.requires_grad:
+            for attr in model_attrs:
+                if n.__contains__(attr):
+                    is_other = False
+                    params.append({
+                        'params': p,
+                        'lr': optim_cfg.LR[attr]
+                    })
+
+            if is_other:
                 params.append({
-                    'params': getattr(model, attr).parameters(),
-                    'lr': optim_cfg.LR[attr]
+                    'params': p,
+                    'lr': optim_cfg.LR["base"]
                 })
-
-    optimizer = get_optimizer(optim_cfg.NAME)(params, tltd(optim_cfg.ARGS))
+
+    # for n, p in m.named_parameters():
+
+    #     if n.__contains__()
+
+    optimizer = get_optimizer(optim_cfg.NAME)(params, optim_cfg.LR.base, tltd(optim_cfg.ARGS))
+    # optimizer = AdamW(params=params, lr=0.0001, weight_decay=5e-4)
 
     return optimizer
diff --git a/fgvclib/configs/config.py b/fgvclib/configs/config.py
index 36383d4..07d4264 100644
--- a/fgvclib/configs/config.py
+++ b/fgvclib/configs/config.py
@@ -115,6 +115,7 @@ def __init__(self):
         self.cfg.OPTIMIZER.ARGS = [{"momentum": 0.9}, {"weight_decay": 5e-4}]
 
         self.cfg.OPTIMIZER.LR = CN()
+        self.cfg.OPTIMIZER.LR.base = None
         self.cfg.OPTIMIZER.LR.backbone = None
         self.cfg.OPTIMIZER.LR.encoder = None
         self.cfg.OPTIMIZER.LR.necks = None
diff --git a/fgvclib/models/sotas/swin_transformer.py b/fgvclib/models/sotas/swin_transformer.py
index 38656f4..ac204e9 100644
--- a/fgvclib/models/sotas/swin_transformer.py
+++ b/fgvclib/models/sotas/swin_transformer.py
@@ -184,4 +184,4 @@ def forward(self, x, target=None):
                 loss_ori = F.cross_entropy(logits[name].float(), target)
                 losses.append(LossItem(name="loss_ori", value=loss_ori, weight=1.0))
 
-        return logits, losses
+        return logits, losses
\ No newline at end of file
diff --git a/fgvclib/optimizers/adam.py b/fgvclib/optimizers/adam.py
index ec75178..f998c18 100644
--- a/fgvclib/optimizers/adam.py
+++ b/fgvclib/optimizers/adam.py
@@ -4,6 +4,6 @@ from . import optimizer
 
 @optimizer("Adam")
-def adam(params, cfg):
-    return Adam(params=params, **cfg)
+def adam(params, lr, cfg):
+    return Adam(params=params, lr=lr, **cfg)
 
diff --git a/fgvclib/optimizers/adamw.py b/fgvclib/optimizers/adamw.py
index 312b37c..8e6b069 100644
--- a/fgvclib/optimizers/adamw.py
+++ b/fgvclib/optimizers/adamw.py
@@ -4,5 +4,5 @@ from . import optimizer
 
 @optimizer("AdamW")
-def adamw(params, cfg):
-    return AdamW(params=params, **cfg)
\ No newline at end of file
+def adamw(params, lr, cfg):
+    return AdamW(params=params, lr=lr, **cfg)
\ No newline at end of file
diff --git a/fgvclib/optimizers/sgd.py b/fgvclib/optimizers/sgd.py
index b618f6b..6d31d8d 100644
--- a/fgvclib/optimizers/sgd.py
+++ b/fgvclib/optimizers/sgd.py
@@ -4,5 +4,5 @@ from . import optimizer
 
 @optimizer("SGD")
-def adamw(params, cfg):
-    return SGD(params=params, **cfg)
\ No newline at end of file
+def sgd(params, lr, cfg):
+    return SGD(params=params, lr=lr, **cfg)
\ No newline at end of file
diff --git a/fgvclib/utils/lr_schedules/cosine_decay_schedule.py b/fgvclib/utils/lr_schedules/cosine_decay_schedule.py
index ff6687c..c51c55e 100644
--- a/fgvclib/utils/lr_schedules/cosine_decay_schedule.py
+++ b/fgvclib/utils/lr_schedules/cosine_decay_schedule.py
@@ -11,7 +11,7 @@ def __init__(self, optimizer, max_epochs, warmup_steps, max_lr, batch_num_per_ep
         total_batchs = max_epochs * batch_num_per_epoch
         iters = np.arange(total_batchs - warmup_steps)
 
-        self.last_iter = -1
+
         self.update_level = 'batch'
 
         if decay_type == 1:
@@ -26,12 +26,11 @@ def __init__(self, optimizer, max_epochs, warmup_steps, max_lr, batch_num_per_ep
             schedule = np.concatenate((warmup_lr_schedule, schedule))
 
         self.schedule = schedule
-        self.step()
 
-    def step(self):
-        self.last_iter += 1
+    def step(self, iteration):
+
         for pg in self.optimizer.param_groups:
-            pg['lr'] = self.schedule[self.last_iter]
+            pg['lr'] = self.schedule[iteration]
 
diff --git a/fgvclib/utils/update_function/update_swin_transformer.py b/fgvclib/utils/update_function/update_swin_transformer.py
index 31a1be8..1534b0c 100644
--- a/fgvclib/utils/update_function/update_swin_transformer.py
+++ b/fgvclib/utils/update_function/update_swin_transformer.py
@@ -23,6 +23,13 @@ def update_swin_transformer(model: nn.Module, optimizer: Optimizer, pbar:Iterabl
     for batch_id, (inputs, targets) in enumerate(pbar):
         model.train()
 
+        if lr_schedule.update_level == "batch":
+            iteration = epoch * len(pbar) + batch_id
+            lr_schedule.step(iteration)
+
+        if lr_schedule.update_level == "epoch":
+            lr_schedule.step()
+
         if use_cuda:
             inputs, targets = inputs.cuda(), targets.cuda()
             inputs, targets = Variable(inputs), Variable(targets)
@@ -44,26 +51,24 @@ def update_swin_transformer(model: nn.Module, optimizer: Optimizer, pbar:Iterabl
         """
         out, losses = model(inputs, targets)
         total_loss = compute_loss_value(losses)
-        total_loss /= model.update_freq
-
+        if hasattr(model, 'update_freq'):
+            update_freq = model.update_freq
+        else:
+            update_freq = model.module.update_freq
+        total_loss /= update_freq
+
         if amp:
             scaler.scale(total_loss).backward()
         else:
             total_loss.backward()
-
+
         losses_info = detach_loss_value(losses)
         losses_info.update({"iter_loss": total_loss.item()})
         pbar.set_postfix(losses_info)
 
-        if (batch_id + 1) % model.update_freq == 0:
+        if (batch_id + 1) % update_freq == 0:
             if amp:
                 scaler.step(optimizer)
                 scaler.update()  # next batch
             else:
                 optimizer.step()
             optimizer.zero_grad()
-
-        if lr_schedule.update_level == "batch":
-            lr_schedule.step()
-
-        if lr_schedule.update_level == "epoch":
-            lr_schedule.step()
\ No newline at end of file
diff --git a/main.py b/main.py
index b63c0f2..05f44c4 100644
--- a/main.py
+++ b/main.py
@@ -62,14 +62,16 @@ def train(cfg: CfgNode):
         dataset=train_set,
         mode_cfg=cfg.DATASET.TRAIN,
         sampler=train_sampler,
-        is_batch_sampler=sampler_cfg.TRAIN.IS_BATCH_SAMPLER
+        # is_batch_sampler=sampler_cfg.TRAIN.IS_BATCH_SAMPLER
+        is_batch_sampler=False
     )
 
     test_loader = build_dataloader(
         dataset=test_set,
         mode_cfg=cfg.DATASET.TEST,
         sampler=test_sampler,
-        is_batch_sampler=sampler_cfg.TEST.IS_BATCH_SAMPLER
+        # is_batch_sampler=sampler_cfg.TEST.IS_BATCH_SAMPLER
+        is_batch_sampler=False
     )
 
     optimizer = build_optimizer(cfg.OPTIMIZER, model)
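The update-function changes in PATCH 08 move the per-batch LR step to the top of the loop (so `cosine_decay_schedule.step(iteration)` is indexed explicitly) and read `update_freq` through `model.module` under DDP. Stripped of the library plumbing, the AMP-plus-gradient-accumulation core looks roughly like this; `autocast` is not visible in the diff and is assumed to be handled around the model's forward:

    import torch
    from torch.cuda.amp import GradScaler, autocast

    def train_epoch(model, loader, optimizer, update_freq=4):
        scaler = GradScaler()
        optimizer.zero_grad()
        for batch_id, (inputs, targets) in enumerate(loader):
            with autocast():
                loss = model(inputs, targets)
            loss = loss / update_freq        # rescale so accumulated grads match one big batch
            scaler.scale(loss).backward()    # backward on the scaled loss
            if (batch_id + 1) % update_freq == 0:
                scaler.step(optimizer)       # unscales grads, skips the step on inf/nan
                scaler.update()
                optimizer.zero_grad()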
From 30b5fdef628127166427de28ef7e2dd0fd314ff2 Mon Sep 17 00:00:00 2001
From: xin-ran-w
Date: Mon, 6 Feb 2023 22:42:47 +0800
Subject: [PATCH 09/10] update config

---
 configs/mutual_channel_loss/mcl_vgg16.yml     | 7 +------
 .../pmg_resnet50.yml                          | 5 -----
 .../pmg_v2_resnet50.yml                       | 6 ------
 configs/resnet/resnet50.yml                   | 2 +-
 configs/swin_transformer/swin_transformer.yml | 7 +------
 configs/transfg/transFG_ViT_B_16.yml          | 6 ------
 6 files changed, 3 insertions(+), 30 deletions(-)

diff --git a/configs/mutual_channel_loss/mcl_vgg16.yml b/configs/mutual_channel_loss/mcl_vgg16.yml
index 005a9b8..9f5492c 100644
--- a/configs/mutual_channel_loss/mcl_vgg16.yml
+++ b/configs/mutual_channel_loss/mcl_vgg16.yml
@@ -137,9 +137,4 @@ METRICS:
     metric: "precision"
     top_k: ~
     threshold: 0.5
-
-INTERPRETER:
-  NAME: "cam"
-  METHOD: "gradcam"
-  TARGET_LAYERS:
-    - "layer4"
\ No newline at end of file
+ 
\ No newline at end of file
diff --git a/configs/progressive_multi_granularity_learning/pmg_resnet50.yml b/configs/progressive_multi_granularity_learning/pmg_resnet50.yml
index 976a374..f0ae46a 100644
--- a/configs/progressive_multi_granularity_learning/pmg_resnet50.yml
+++ b/configs/progressive_multi_granularity_learning/pmg_resnet50.yml
@@ -143,8 +143,3 @@ METRICS:
     top_k: ~
     threshold: 0.5
 
-INTERPRETER:
-  NAME: "cam"
-  METHOD: "gradcam"
-  TARGET_LAYERS:
-    - "layer4"
diff --git a/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml b/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml
index ff26f49..c9b5635 100644
--- a/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml
+++ b/configs/progressive_multi_granularity_learning/pmg_v2_resnet50.yml
@@ -153,9 +153,3 @@ METRICS:
     metric: "precision"
     top_k: ~
     threshold: 0.5
-
-INTERPRETER:
-  NAME: "cam"
-  METHOD: "gradcam"
-  TARGET_LAYERS:
-    - "layer4"
\ No newline at end of file
diff --git a/configs/resnet/resnet50.yml b/configs/resnet/resnet50.yml
index fdd9976..717f113 100644
--- a/configs/resnet/resnet50.yml
+++ b/configs/resnet/resnet50.yml
@@ -99,7 +99,7 @@ OPTIMIZER:
     heads: 0.002
 
 ITERATION_NUM: ~
-EPOCH_NUM: 1
+EPOCH_NUM: 50
 START_EPOCH: 0
 UPDATE_STRATEGY: "general_strategy"
 
diff --git a/configs/swin_transformer/swin_transformer.yml b/configs/swin_transformer/swin_transformer.yml
index 728aa11..bd228cc 100644
--- a/configs/swin_transformer/swin_transformer.yml
+++ b/configs/swin_transformer/swin_transformer.yml
@@ -165,9 +165,4 @@ METRICS:
     metric: "precision"
     top_k: ~
     threshold: 0.5
-
-INTERPRETER:
-  NAME: "cam"
-  METHOD: "gradcam"
-  TARGET_LAYERS:
-    - "layer4"
\ No newline at end of file
+ 
\ No newline at end of file
diff --git a/configs/transfg/transFG_ViT_B_16.yml b/configs/transfg/transFG_ViT_B_16.yml
index a946b58..a59551c 100644
--- a/configs/transfg/transFG_ViT_B_16.yml
+++ b/configs/transfg/transFG_ViT_B_16.yml
@@ -157,9 +157,3 @@ METRICS:
     metric: "precision"
     top_k: ~
     threshold: 0.5
-
-INTERPRETER:
-  NAME: "cam"
-  METHOD: "gradcam"
-  TARGET_LAYERS:
-    - "layer4"
\ No newline at end of file

From 88142403ed3901245b1420db55b21139033cd0ac Mon Sep 17 00:00:00 2001
From: xin-ran-w
Date: Mon, 6 Feb 2023 22:49:40 +0800
Subject: [PATCH 10/10] fix multi-gpu bug

---
 main.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/main.py b/main.py
index 05f44c4..c2c9720 100644
--- a/main.py
+++ b/main.py
@@ -49,12 +49,13 @@ def train(cfg: CfgNode):
     )
 
     model.to(device)
+    sampler_cfg = cfg.SAMPLER
 
     if cfg.DISTRIBUTED:
         model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True, device_ids=[cfg.GPU])
         train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
         test_sampler = torch.utils.data.distributed.DistributedSampler(test_set)
+        sampler_cfg.TRAIN.IS_BATCH_SAMPLER = False
     else:
-        sampler_cfg = cfg.SAMPLER
         train_sampler = build_sampler(sampler_cfg.TRAIN)(train_set, **tltd(sampler_cfg.TRAIN.ARGS))
         test_sampler = build_sampler(sampler_cfg.TEST)(test_set, **tltd(sampler_cfg.TEST.ARGS))
 
@@ -62,16 +63,13 @@ def train(cfg: CfgNode):
         dataset=train_set,
         mode_cfg=cfg.DATASET.TRAIN,
         sampler=train_sampler,
-        # is_batch_sampler=sampler_cfg.TRAIN.IS_BATCH_SAMPLER
-        is_batch_sampler=False
+        is_batch_sampler=sampler_cfg.TRAIN.IS_BATCH_SAMPLER
     )
 
     test_loader = build_dataloader(
         dataset=test_set,
         mode_cfg=cfg.DATASET.TEST,
         sampler=test_sampler,
-        # is_batch_sampler=sampler_cfg.TEST.IS_BATCH_SAMPLER
-        is_batch_sampler=False
     )
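Two details make the fix above work: DistributedSampler is a plain per-rank sampler, so IS_BATCH_SAMPLER has to be forced off when DDP is active, and main.py's existing `train_sampler.set_epoch(epoch)` call is what reshuffles each rank's shard between epochs. A minimal sketch of the sampler side (num_replicas/rank are passed explicitly here so the snippet runs without an initialized process group; under torchrun they are inferred from it):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    train_set = TensorDataset(torch.randn(64, 3), torch.randint(0, 200, (64,)))
    sampler = DistributedSampler(train_set, num_replicas=2, rank=0, shuffle=True)
    loader = DataLoader(train_set, batch_size=8, sampler=sampler)  # sampler, not batch_sampler

    for epoch in range(2):
        sampler.set_epoch(epoch)  # changes the shuffle seed so each epoch sees a new order
        for inputs, targets in loader:
            pass  # forward/backward/step handled by the update function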