
Make bn weight decay configurable (#65)
Summary:
Pull Request resolved: fairinternal/ClassyVision#65

Make the batch norm (bn) weight decay configurable; for some datasets it might be desirable to turn it off. A brief usage sketch follows the commit metadata below.

Reviewed By: vreis

Differential Revision: D20140487

fbshipit-source-id: 77debf2c4600a080081668565d70b7a3ddc788f4
Aaron Adcock authored and facebook-github-bot committed Mar 2, 2020
1 parent e47a18d commit 553c18c
Showing 1 changed file with 17 additions and 3 deletions.
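A minimal usage sketch (not part of the diff): constructing the model directly with the new flag. Only bn_weight_decay comes from this commit; the other arguments are illustrative values for a ResNeXt-50 (32x4d)-style model.

# Hypothetical sketch: build a ResNeXt directly and disable weight decay on
# the batch norm parameters via the new constructor argument.
from classy_vision.models.resnext import ResNeXt

model = ResNeXt(
    num_blocks=[3, 4, 6, 3],             # illustrative ResNeXt-50 stage depths
    base_width_and_cardinality=(4, 32),  # illustrative 32x4d configuration
    zero_init_bn_residuals=True,
    bn_weight_decay=False,               # new flag from this commit (defaults to False)
)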
20 changes: 17 additions & 3 deletions classy_vision/models/resnext.py
@@ -235,6 +235,7 @@ def __init__(
base_width_and_cardinality: Optional[Union[Tuple, List]] = None,
basic_layer: bool = False,
final_bn_relu: bool = True,
bn_weight_decay: Optional[bool] = False,
):
"""
Implementation of `ResNeXt <https://arxiv.org/pdf/1611.05431.pdf>`_.
@@ -251,6 +252,7 @@ def __init__(
assert all(is_pos_int(n) for n in num_blocks)
assert is_pos_int(init_planes) and is_pos_int(reduction)
assert type(small_input) == bool
assert type(bn_weight_decay) == bool
assert (
type(zero_init_bn_residuals) == bool
), "zero_init_bn_residuals must be a boolean, set to true if gamma of last\
@@ -262,9 +264,11 @@ def __init__(
and is_pos_int(base_width_and_cardinality[1])
)

# we apply weight decay to batch norm if the model is a ResNeXt and we don't if
# it is a ResNet
self.bn_weight_decay = base_width_and_cardinality is not None
# Chooses whether to apply weight decay to batch norm
# parameters. This improves results in some situations,
# e.g. ResNeXt models trained / evaluated using the Imagenet
# dataset, but can cause worse performance in other scenarios
self.bn_weight_decay = bn_weight_decay

# initial convolutional block:
self.num_blocks = num_blocks
@@ -374,6 +378,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
"basic_layer": config.get("basic_layer", False),
"final_bn_relu": config.get("final_bn_relu", True),
"zero_init_bn_residuals": config.get("zero_init_bn_residuals", False),
"bn_weight_decay": config.get("bn_weight_decay", False),
}
return cls(**config)

@@ -476,6 +481,12 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


# Note, the ResNeXt models all have weight decay enabled for the batch
# norm parameters. We have found empirically that this gives better
# results when training on ImageNet (~0.5pp of top-1 acc) and brings
# our results on track with reported ImageNet results...but for
# training on other datasets, we have observed losses in accuracy (for
# example, the dataset used in https://arxiv.org/abs/1805.00932).
@register_model("resnext50_32x4d")
class ResNeXt50(ResNeXt):
def __init__(self):
@@ -484,6 +495,7 @@ def __init__(self):
basic_layer=False,
zero_init_bn_residuals=True,
base_width_and_cardinality=(4, 32),
bn_weight_decay=True,
)

@classmethod
@@ -499,6 +511,7 @@ def __init__(self):
basic_layer=False,
zero_init_bn_residuals=True,
base_width_and_cardinality=(4, 32),
bn_weight_decay=True,
)

@classmethod
@@ -514,6 +527,7 @@ def __init__(self):
basic_layer=False,
zero_init_bn_residuals=True,
base_width_and_cardinality=(4, 32),
bn_weight_decay=True,
)

@classmethod
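As the from_config change above shows, the flag can also be read from a model config. A hedged sketch, assuming the base ResNeXt class is registered under the "resnext" model name and that build_model is used to construct models from config dicts; all keys other than bn_weight_decay are illustrative.

from classy_vision.models import build_model

config = {
    "name": "resnext",                      # assumes the base class is registered as "resnext"
    "num_blocks": [3, 4, 6, 3],             # illustrative stage depths
    "base_width_and_cardinality": [4, 32],  # illustrative 32x4d configuration
    "bn_weight_decay": False,               # new key; from_config defaults it to False
}
model = build_model(config)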
