Skip to content

Commit

Permalink
[BugFixed] fix wrong trunc_normal_init use (open-mmlab#6432)
Browse files Browse the repository at this point in the history
* fix wrong trunc_normal_init use

* fix wrong trunc_normal_init use
  • Loading branch information
vealocia authored Nov 8, 2021
1 parent 1ab934e commit 6cf9aa1
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
7 changes: 3 additions & 4 deletions mmdet/models/backbones/pvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
constant_init, normal_init, trunc_normal_init)
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import MultiheadAttention
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner import (BaseModule, ModuleList, Sequential, _load_checkpoint,
load_state_dict)
from torch.nn.modules.utils import _pair as to_2tuple
Expand Down Expand Up @@ -275,7 +276,7 @@ def __init__(self, pos_shape, pos_dim, drop_rate=0., init_cfg=None):
self.drop = nn.Dropout(p=drop_rate)

def init_weights(self):
trunc_normal_init(self.pos_embed, std=0.02)
trunc_normal_(self.pos_embed, std=0.02)

def resize_pos_embed(self, pos_embed, input_shape, mode='bilinear'):
"""Resize pos_embed weights.
Expand Down Expand Up @@ -486,9 +487,7 @@ def init_weights(self):
f'training start from scratch')
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m.weight, std=.02)
if m.bias is not None:
constant_init(m.bias, 0)
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)
Expand Down
9 changes: 4 additions & 5 deletions mmdet/models/backbones/swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
from mmcv.utils import to_2tuple

Expand Down Expand Up @@ -74,7 +75,7 @@ def __init__(self,
self.softmax = nn.Softmax(dim=-1)

def init_weights(self):
trunc_normal_init(self.relative_position_bias_table, std=0.02)
trunc_normal_(self.relative_position_bias_table, std=0.02)

def forward(self, x, mask=None):
"""
Expand Down Expand Up @@ -672,12 +673,10 @@ def init_weights(self):
f'{self.__class__.__name__}, '
f'training start from scratch')
if self.use_abs_pos_embed:
trunc_normal_init(self.absolute_pos_embed, std=0.02)
trunc_normal_(self.absolute_pos_embed, std=0.02)
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m.weight, std=.02)
if m.bias is not None:
constant_init(m.bias, 0)
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)
Expand Down

0 comments on commit 6cf9aa1

Please sign in to comment.