
[Enhancement] Support hrnet frozen stage (#743)
* support hrnet frozen stage

* support hrnet frozen stage
sshuair authored Aug 3, 2021
1 parent 52b4fa5 commit f934084
Showing 2 changed files with 96 additions and 0 deletions.
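
As a quick usage note (not part of this commit): the new option is a constructor argument of the backbone, so in an mmsegmentation config it can be enabled by overriding the backbone dict of any existing HRNet config. A minimal sketch, assuming the usual config-inheritance layout (the base config path below is only illustrative):

# Hypothetical override config: reuse an existing HRNet config and freeze
# the stem plus stage 1. frozen_stages=-1 (the default) freezes nothing.
_base_ = './fcn_hr18_512x1024_40k_cityscapes.py'
model = dict(backbone=dict(frozen_stages=1))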
33 changes: 33 additions & 0 deletions mmseg/models/backbones/hrnet.py
@@ -230,6 +230,8 @@ class HRNet(BaseModule):
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.
        pretrained (str, optional): model pretrained path. Default: None
@@ -285,6 +287,7 @@ def __init__(self,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 with_cp=False,
                 frozen_stages=-1,
                 zero_init_residual=False,
                 pretrained=None,
                 init_cfg=None):
@@ -315,6 +318,7 @@ def __init__(self,
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.frozen_stages = frozen_stages

        # stem net
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
@@ -388,6 +392,8 @@ def __init__(self,
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels)

        self._freeze_stages()

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
@@ -534,6 +540,32 @@ def _make_stage(self, layer_config, in_channels, multiscale_output=True):

        return Sequential(*hr_modules), in_channels

    def _freeze_stages(self):
        """Freeze the parameters and the normalization statistics of the stem
        and of the first ``frozen_stages`` stages."""
        if self.frozen_stages >= 0:
            # Freeze the stem: conv1/norm1/conv2/norm2.
            self.norm1.eval()
            self.norm2.eval()
            for m in [self.conv1, self.norm1, self.conv2, self.norm2]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            # Stage 1 is built as `layer1`; later stages are `stage{i}`.
            # There is no `transition4`, so for stage 4 only the stage module
            # is fetched and the previously frozen transition is reused.
            if i == 1:
                m = getattr(self, f'layer{i}')
                t = getattr(self, f'transition{i}')
            elif i == 4:
                m = getattr(self, f'stage{i}')
            else:
                m = getattr(self, f'stage{i}')
                t = getattr(self, f'transition{i}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False
            t.eval()
            for param in t.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Forward function."""

@@ -575,6 +607,7 @@ def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super(HRNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval has an effect on BatchNorm only
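
For reference, a minimal sketch (not part of the diff) of the intended behaviour: because train() re-applies _freeze_stages(), the stem and the frozen stages stay in eval mode with gradients disabled even after the model is switched back to training mode. The extra dict mirrors the one used in the new test below.

# Sketch only; assumes mmsegmentation with this commit is installed.
from mmseg.models.backbones import HRNet

extra = dict(
    stage1=dict(num_modules=1, num_branches=1, block='BOTTLENECK',
                num_blocks=(4, ), num_channels=(64, )),
    stage2=dict(num_modules=1, num_branches=2, block='BASIC',
                num_blocks=(4, 4), num_channels=(32, 64)),
    stage3=dict(num_modules=4, num_branches=3, block='BASIC',
                num_blocks=(4, 4, 4), num_channels=(32, 64, 128)),
    stage4=dict(num_modules=3, num_branches=4, block='BASIC',
                num_blocks=(4, 4, 4, 4), num_channels=(32, 64, 128, 256)))

backbone = HRNet(extra, frozen_stages=1)
backbone.train()  # train() calls _freeze_stages() again

assert backbone.norm1.training is False  # stem BN kept in eval mode
assert all(not p.requires_grad for p in backbone.layer1.parameters())
assert all(p.requires_grad for p in backbone.stage2.parameters())  # still trainable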
63 changes: 63 additions & 0 deletions tests/test_models/test_backbones/test_hrnet.py
@@ -0,0 +1,63 @@
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmseg.models.backbones import HRNet


def test_hrnet_backbone():
    # Test HRNet with the first two stages frozen

    extra = dict(
        stage1=dict(
            num_modules=1,
            num_branches=1,
            block='BOTTLENECK',
            num_blocks=(4, ),
            num_channels=(64, )),
        stage2=dict(
            num_modules=1,
            num_branches=2,
            block='BASIC',
            num_blocks=(4, 4),
            num_channels=(32, 64)),
        stage3=dict(
            num_modules=4,
            num_branches=3,
            block='BASIC',
            num_blocks=(4, 4, 4),
            num_channels=(32, 64, 128)),
        stage4=dict(
            num_modules=3,
            num_branches=4,
            block='BASIC',
            num_blocks=(4, 4, 4, 4),
            num_channels=(32, 64, 128, 256)))
    frozen_stages = 2
    model = HRNet(extra, frozen_stages=frozen_stages)
    model.init_weights()
    model.train()
    assert model.norm1.training is False

    for layer in [model.conv1, model.norm1]:
        for param in layer.parameters():
            assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        if i == 1:
            layer = getattr(model, f'layer{i}')
            transition = getattr(model, f'transition{i}')
        elif i == 4:
            layer = getattr(model, f'stage{i}')
        else:
            layer = getattr(model, f'stage{i}')
            transition = getattr(model, f'transition{i}')

        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False

        for mod in transition.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in transition.parameters():
            assert param.requires_grad is False
