From da1da48eb6bb8285b28277f1dd06ca30ffbe3dfe Mon Sep 17 00:00:00 2001 From: ZhangYiqin <312065559@qq.com> Date: Mon, 4 Sep 2023 13:11:16 +0800 Subject: [PATCH] [Enhance] Add iTPN Supports for Non-three channel image (#1735) * Add channel argments to mae_head When trying iTPN pretrain, it only supports images with 3 channels. One of the restrictions is from MAEHead. * Transfer other argments from iTPNHiViT to HiViT The HiViT supports specifying channels, but the iTPNHiViT class can't pass channel argments to it. This is one of the reasons that iTPNHiViT implementation only support images with 3 channels. * Update itpn.py Fix hint problem --- mmpretrain/models/heads/mae_head.py | 23 +++++++++++++---------- mmpretrain/models/selfsup/itpn.py | 5 ++++- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/mmpretrain/models/heads/mae_head.py b/mmpretrain/models/heads/mae_head.py index 1a5366d13b5..b76ecedd96d 100644 --- a/mmpretrain/models/heads/mae_head.py +++ b/mmpretrain/models/heads/mae_head.py @@ -14,15 +14,18 @@ class MAEPretrainHead(BaseModule): norm_pix_loss (bool): Whether or not normalize target. Defaults to False. patch_size (int): Patch size. Defaults to 16. + in_channels (int): Number of input channels. Defaults to 3. """ def __init__(self, loss: dict, norm_pix: bool = False, - patch_size: int = 16) -> None: + patch_size: int = 16, + in_channels: int = 3) -> None: super().__init__() self.norm_pix = norm_pix self.patch_size = patch_size + self.in_channels = in_channels self.loss_module = MODELS.build(loss) def patchify(self, imgs: torch.Tensor) -> torch.Tensor: @@ -30,19 +33,19 @@ def patchify(self, imgs: torch.Tensor) -> torch.Tensor: Args: imgs (torch.Tensor): A batch of images. The shape should - be :math:`(B, 3, H, W)`. + be :math:`(B, C, H, W)`. Returns: torch.Tensor: Patchified images. The shape is - :math:`(B, L, \text{patch_size}^2 \times 3)`. + :math:`(B, L, \text{patch_size}^2 \times C)`. """ p = self.patch_size assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 h = w = imgs.shape[2] // p - x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = imgs.reshape(shape=(imgs.shape[0], self.in_channels, h, p, w, p)) x = torch.einsum('nchpwq->nhwpqc', x) - x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_channels)) return x def unpatchify(self, x: torch.Tensor) -> torch.Tensor: @@ -50,18 +53,18 @@ def unpatchify(self, x: torch.Tensor) -> torch.Tensor: Args: x (torch.Tensor): The shape is - :math:`(B, L, \text{patch_size}^2 \times 3)`. + :math:`(B, L, \text{patch_size}^2 \times C)`. Returns: - torch.Tensor: The shape is :math:`(B, 3, H, W)`. + torch.Tensor: The shape is :math:`(B, C, H, W)`. """ p = self.patch_size h = w = int(x.shape[1]**.5) assert h * w == x.shape[1] - x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_channels)) x = torch.einsum('nhwpqc->nchpwq', x) - imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + imgs = x.reshape(shape=(x.shape[0], self.in_channels, h * p, h * p)) return imgs def construct_target(self, target: torch.Tensor) -> torch.Tensor: @@ -71,7 +74,7 @@ def construct_target(self, target: torch.Tensor) -> torch.Tensor: normalize the image according to ``norm_pix``. Args: - target (torch.Tensor): Image with the shape of B x 3 x H x W + target (torch.Tensor): Image with the shape of B x C x H x W Returns: torch.Tensor: Tokenized images with the shape of B x L x C diff --git a/mmpretrain/models/selfsup/itpn.py b/mmpretrain/models/selfsup/itpn.py index 85efd254053..488a9963182 100644 --- a/mmpretrain/models/selfsup/itpn.py +++ b/mmpretrain/models/selfsup/itpn.py @@ -64,6 +64,7 @@ def __init__( layer_scale_init_value: float = 0.0, mask_ratio: float = 0.75, reconstruction_type: str = 'pixel', + **kwargs, ): super().__init__( arch=arch, @@ -80,7 +81,9 @@ def __init__( norm_cfg=norm_cfg, ape=ape, rpe=rpe, - layer_scale_init_value=layer_scale_init_value) + layer_scale_init_value=layer_scale_init_value, + **kwargs, + ) self.pos_embed.requires_grad = False self.mask_ratio = mask_ratio