From 6adb291449231a5e5377ee1cb3e250a220300631 Mon Sep 17 00:00:00 2001
From: Felix Hieber
Date: Wed, 25 Apr 2018 17:38:52 +0200
Subject: [PATCH] Proper reloading of learning rate scheduler state at training
 resumption (#369)

fixes #366
---
 CHANGELOG.md          |  4 ++++
 sockeye/__init__.py   |  2 +-
 sockeye/optimizers.py | 13 ++++++++-----
 sockeye/train.py      |  4 ++--
 sockeye/training.py   |  7 ++++---
 5 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 58b6dd7fb..63065be59 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa
 
 Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
 
+## [1.18.6]
+### Fixed
+- Fixed a problem with learning rate scheduler not properly being loaded when resuming training.
+
 ## [1.18.5]
 ### Fixed
 - Fixed a problem with trainer not waiting for the last checkpoint decoder (#367).
diff --git a/sockeye/__init__.py b/sockeye/__init__.py
index 0f678d58e..0c689cd9d 100644
--- a/sockeye/__init__.py
+++ b/sockeye/__init__.py
@@ -11,4 +11,4 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-__version__ = '1.18.5'
+__version__ = '1.18.6'
diff --git a/sockeye/optimizers.py b/sockeye/optimizers.py
index 418fad7ca..81137b170 100644
--- a/sockeye/optimizers.py
+++ b/sockeye/optimizers.py
@@ -38,8 +38,7 @@ def __init__(self,
                  kvstore: str,
                  initializer: mx.initializer.Initializer,
                  gradient_clipping_type: str,
-                 gradient_clipping_threshold: Optional[float],
-                 lr_scheduler: Optional[LearningRateScheduler]) -> None:
+                 gradient_clipping_threshold: Optional[float]) -> None:
         super().__init__()
         self.name = name
         self.params = params
@@ -47,9 +46,13 @@ def __init__(self,
         self.initializer = initializer
         self.gradient_clipping_type = gradient_clipping_type
         self.gradient_clipping_threshold = gradient_clipping_threshold
-        self.lr_scheduler = lr_scheduler
-        if lr_scheduler is not None:
-            self.params["lr_scheduler"] = lr_scheduler
+
+    @property
+    def lr_scheduler(self) -> Optional[LearningRateScheduler]:
+        return self.params.get("lr_scheduler", None)
+
+    def set_lr_scheduler(self, lr_scheduler: Optional[LearningRateScheduler]):
+        self.params["lr_scheduler"] = lr_scheduler
 
 
 class SockeyeOptimizer(mx.optimizer.Optimizer):
diff --git a/sockeye/train.py b/sockeye/train.py
index eed7d6478..83fa7657d 100644
--- a/sockeye/train.py
+++ b/sockeye/train.py
@@ -714,8 +714,8 @@ def create_optimizer_config(args: argparse.Namespace, source_vocab_sizes: List[i
                              kvstore=args.kvstore,
                              initializer=weight_init,
                              gradient_clipping_type=gradient_clipping_type,
-                             gradient_clipping_threshold=gradient_clipping_threshold,
-                             lr_scheduler=lr_sched)
+                             gradient_clipping_threshold=gradient_clipping_threshold)
+    config.set_lr_scheduler(lr_sched)
     logger.info("Optimizer: %s", config)
     logger.info("Gradient Compression: %s", gradient_compression_params(args))
     return config
diff --git a/sockeye/training.py b/sockeye/training.py
index e8aa1ff5b..b813b2d65 100644
--- a/sockeye/training.py
+++ b/sockeye/training.py
@@ -886,7 +886,8 @@ def _check_args(self,
             utils.check_condition(self.optimizer_config.name != C.OPTIMIZER_EVE,
                                   "Eve optimizer not supported with distributed training.")
             utils.check_condition(
-                not issubclass(type(self.optimizer_config.lr_scheduler), lr_scheduler.AdaptiveLearningRateScheduler),
+                not issubclass(type(self.optimizer_config.lr_scheduler),
+                               lr_scheduler.AdaptiveLearningRateScheduler),
                 "Adaptive learning rate schedulers not supported with a dist kvstore. "
                 "Try a fixed schedule such as %s." % C.LR_SCHEDULER_FIXED_RATE_INV_SQRT_T)
             utils.check_condition(not lr_decay_param_reset, "Parameter reset when the learning rate decays not "
@@ -987,9 +988,9 @@ def _load_training_state(self, train_iter: data_io.BaseParallelSampleIter):
 
         # (6) Learning rate scheduler
         with open(os.path.join(self.training_state_dirname, C.SCHEDULER_STATE_NAME), "rb") as fp:
-            self.optimizer_config.lr_scheduler = pickle.load(fp)
+            self.optimizer_config.set_lr_scheduler(pickle.load(fp))
         # initialize optimizer again
-        self.model.initialize_optimizer(self.optimizer_config)
+        self._initialize_optimizer()
 
 
 class TensorboardLogger:
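
Illustrative note (not part of the patch): as the diff suggests, the old code kept the scheduler both as a plain attribute and inside the optimizer `params` dict, so reassigning `optimizer_config.lr_scheduler` on resumption appeared to leave the dict entry stale. The patch makes `params` the single source of truth, exposed through a read-only property plus an explicit setter. The following standalone sketch mimics that property/setter pattern with hypothetical stand-in classes and values; it is not Sockeye code.

    import pickle
    from typing import Any, Dict, Optional


    class Scheduler:
        """Stand-in for a learning rate scheduler (illustrative only)."""
        def __init__(self, base_lr: float) -> None:
            self.base_lr = base_lr


    class Config:
        """Mimics the patched OptimizerConfig: the scheduler lives only in `params`."""
        def __init__(self, name: str, params: Dict[str, Any]) -> None:
            self.name = name
            self.params = params

        @property
        def lr_scheduler(self) -> Optional[Scheduler]:
            # Read-through accessor: whatever is in params is what the optimizer sees.
            return self.params.get("lr_scheduler", None)

        def set_lr_scheduler(self, lr_scheduler: Optional[Scheduler]) -> None:
            # Single point of update: changing params also changes the property value.
            self.params["lr_scheduler"] = lr_scheduler


    config = Config("adam", {"learning_rate": 0.0003})
    config.set_lr_scheduler(Scheduler(base_lr=0.0003))

    # Simulate resuming training: restore a pickled scheduler and re-attach it
    # through the setter before the optimizer would be initialized again.
    state = pickle.dumps(config.lr_scheduler)
    config.set_lr_scheduler(pickle.loads(state))
    assert config.params["lr_scheduler"] is config.lr_scheduler

Because the restored scheduler now goes through `set_lr_scheduler`, the dict handed to the optimizer and the value read back via the property can no longer diverge, which is the behavior the patch relies on when it re-initializes the optimizer after loading the training state.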