Skip to content

Commit

Permalink
add ability to ignore any parameters whose name starts with some prefix
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 9, 2023
1 parent e747c4d commit a3d9583
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
8 changes: 8 additions & 0 deletions ema_pytorch/ema_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
min_value = 0.0,
param_or_buffer_names_no_ema = set(),
ignore_names = set(),
ignore_startswith_names = set()
):
super().__init__()
self.beta = beta
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(
self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer

self.ignore_names = ignore_names
self.ignore_startswith_names = ignore_startswith_names

self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0]))
Expand Down Expand Up @@ -138,6 +140,9 @@ def update_moving_average(self, ma_model, current_model):
if name in self.ignore_names:
continue

if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
continue

if name in self.param_or_buffer_names_no_ema:
ma_params.data.copy_(current_params.data)
continue
Expand All @@ -150,6 +155,9 @@ def update_moving_average(self, ma_model, current_model):
if name in self.ignore_names:
continue

if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
continue

if name in self.param_or_buffer_names_no_ema:
ma_buffer.data.copy_(current_buffer.data)
continue
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'ema-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.2',
version = '0.1.4',
license='MIT',
description = 'Easy way to keep track of exponential moving average version of your pytorch module',
author = 'Phil Wang',
Expand Down

0 comments on commit a3d9583

Please sign in to comment.