
Commit

Proper reloading of learning rate scheduler state at training resumption (#369)

fixes #366
fhieber authored Apr 25, 2018
1 parent 12fdcaa commit 6adb291
Showing 5 changed files with 19 additions and 11 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa

 Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
 
+## [1.18.6]
+### Fixed
+- Fixed a problem with learning rate scheduler not properly being loaded when resuming training.
+
 ## [1.18.5]
 ### Fixed
 - Fixed a problem with trainer not waiting for the last checkpoint decoder (#367).
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-__version__ = '1.18.5'
+__version__ = '1.18.6'
13 changes: 8 additions & 5 deletions sockeye/optimizers.py
@@ -38,18 +38,21 @@ def __init__(self,
                  kvstore: str,
                  initializer: mx.initializer.Initializer,
                  gradient_clipping_type: str,
-                 gradient_clipping_threshold: Optional[float],
-                 lr_scheduler: Optional[LearningRateScheduler]) -> None:
+                 gradient_clipping_threshold: Optional[float]) -> None:
         super().__init__()
         self.name = name
         self.params = params
         self.kvstore = kvstore
         self.initializer = initializer
         self.gradient_clipping_type = gradient_clipping_type
         self.gradient_clipping_threshold = gradient_clipping_threshold
-        self.lr_scheduler = lr_scheduler
-        if lr_scheduler is not None:
-            self.params["lr_scheduler"] = lr_scheduler
+
+    @property
+    def lr_scheduler(self) -> Optional[LearningRateScheduler]:
+        return self.params.get("lr_scheduler", None)
+
+    def set_lr_scheduler(self, lr_scheduler: Optional[LearningRateScheduler]):
+        self.params["lr_scheduler"] = lr_scheduler
 
 
 class SockeyeOptimizer(mx.optimizer.Optimizer):
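For readers skimming the diff, here is a minimal, self-contained sketch of the pattern the reworked OptimizerConfig uses: the learning rate scheduler is kept only inside the optimizer params dict and exposed through a property plus a set_lr_scheduler() setter. This is an illustration, not the actual Sockeye class (which also carries kvstore, initializer and gradient clipping settings as shown above), and the stub LearningRateScheduler stands in for sockeye.lr_scheduler.LearningRateScheduler.

from typing import Any, Dict, Optional


class LearningRateScheduler:
    """Stub standing in for sockeye.lr_scheduler.LearningRateScheduler."""


class OptimizerConfig:
    """Simplified sketch: the scheduler lives only in `params`, so the object
    handed to the MXNet optimizer and the one read back via `lr_scheduler`
    cannot diverge."""

    def __init__(self, name: str, params: Dict[str, Any]) -> None:
        self.name = name
        self.params = params

    @property
    def lr_scheduler(self) -> Optional[LearningRateScheduler]:
        # Single source of truth: read the scheduler out of the params dict.
        return self.params.get("lr_scheduler", None)

    def set_lr_scheduler(self, lr_scheduler: Optional[LearningRateScheduler]) -> None:
        # Whatever is stored here is exactly what the optimizer will be built with.
        self.params["lr_scheduler"] = lr_scheduler


config = OptimizerConfig(name="adam", params={"learning_rate": 0.0003})
config.set_lr_scheduler(LearningRateScheduler())
assert config.lr_scheduler is config.params["lr_scheduler"]

Because train.py (below) now calls config.set_lr_scheduler(lr_sched) after constructing the config, both a freshly created and a reloaded scheduler take the same path into the optimizer.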
4 changes: 2 additions & 2 deletions sockeye/train.py
@@ -714,8 +714,8 @@ def create_optimizer_config(args: argparse.Namespace, source_vocab_sizes: List[i
                              kvstore=args.kvstore,
                              initializer=weight_init,
                              gradient_clipping_type=gradient_clipping_type,
-                             gradient_clipping_threshold=gradient_clipping_threshold,
-                             lr_scheduler=lr_sched)
+                             gradient_clipping_threshold=gradient_clipping_threshold)
+    config.set_lr_scheduler(lr_sched)
     logger.info("Optimizer: %s", config)
     logger.info("Gradient Compression: %s", gradient_compression_params(args))
     return config
7 changes: 4 additions & 3 deletions sockeye/training.py
@@ -886,7 +886,8 @@ def _check_args(self,
             utils.check_condition(self.optimizer_config.name != C.OPTIMIZER_EVE,
                                   "Eve optimizer not supported with distributed training.")
             utils.check_condition(
-                not issubclass(type(self.optimizer_config.lr_scheduler), lr_scheduler.AdaptiveLearningRateScheduler),
+                not issubclass(type(self.optimizer_config.lr_scheduler),
+                               lr_scheduler.AdaptiveLearningRateScheduler),
                 "Adaptive learning rate schedulers not supported with a dist kvstore. "
                 "Try a fixed schedule such as %s." % C.LR_SCHEDULER_FIXED_RATE_INV_SQRT_T)
             utils.check_condition(not lr_decay_param_reset, "Parameter reset when the learning rate decays not "
@@ -987,9 +988,9 @@ def _load_training_state(self, train_iter: data_io.BaseParallelSampleIter):

         # (6) Learning rate scheduler
         with open(os.path.join(self.training_state_dirname, C.SCHEDULER_STATE_NAME), "rb") as fp:
-            self.optimizer_config.lr_scheduler = pickle.load(fp)
+            self.optimizer_config.set_lr_scheduler(pickle.load(fp))
         # initialize optimizer again
-        self.model.initialize_optimizer(self.optimizer_config)
+        self._initialize_optimizer()
 
 
 class TensorboardLogger:
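As a rough sketch of the resume path that the training.py hunk above changes: the pickled scheduler is read back from the training state directory and routed through the new setter, and the optimizer is then initialized again so it actually picks up the restored object. The helper name load_scheduler_state and the file name "scheduler.pkl" are illustrative assumptions; in Sockeye the real file name comes from C.SCHEDULER_STATE_NAME and this logic lives inside _load_training_state.

import os
import pickle

# Illustrative stand-in for C.SCHEDULER_STATE_NAME; the real constant lives in sockeye.constants.
SCHEDULER_STATE_NAME = "scheduler.pkl"


def load_scheduler_state(training_state_dirname: str, optimizer_config) -> None:
    """Unpickle the saved scheduler and hand it to the config through the setter."""
    scheduler_path = os.path.join(training_state_dirname, SCHEDULER_STATE_NAME)
    with open(scheduler_path, "rb") as fp:
        # Restores the scheduler object, including its internal step counters.
        restored_scheduler = pickle.load(fp)
    # Going through set_lr_scheduler() puts the object into optimizer_config.params,
    # which is what the re-created optimizer reads, instead of a stale attribute copy.
    optimizer_config.set_lr_scheduler(restored_scheduler)

After this step the diff re-initializes the optimizer (self._initialize_optimizer()), so the restored schedule drives the learning rate from the first post-resume update.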
