From 2e39633ee6ec0db06a2c6b2b8676f3ba53c43913 Mon Sep 17 00:00:00 2001 From: Steven Bradtke Date: Thu, 14 Mar 2019 12:20:09 -0400 Subject: [PATCH] Stop after n checkpoints (#661) This change add the ability to stop training after a given number of checkpoints. Useful for situations where one is required to share scarce GPU resources. Added a new command line argument '--max-checkpoints' --- CHANGELOG.md | 4 ++++ sockeye/__init__.py | 2 +- sockeye/arguments.py | 6 ++++++ sockeye/constants.py | 2 +- sockeye/image_captioning/train.py | 2 ++ sockeye/train.py | 2 ++ sockeye/training.py | 9 +++++++++ test/unit/test_arguments.py | 1 + 8 files changed, 26 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c2e5c9a8..6b08bad86 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_. +## [1.18.85] +### Fixed +- Added parameter to force training to stop after a given number of checkpoints. Useful when forced to share limited GPU resources. + ## [1.18.84] ### Fixed - Fixed lexical constraints bugs that broke batching and caused large drop in BLEU. diff --git a/sockeye/__init__.py b/sockeye/__init__.py index d28a904eb..ea9f50655 100644 --- a/sockeye/__init__.py +++ b/sockeye/__init__.py @@ -11,4 +11,4 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -__version__ = '1.18.84' +__version__ = '1.18.85' diff --git a/sockeye/arguments.py b/sockeye/arguments.py index 704b4a4dc..6e00933c0 100644 --- a/sockeye/arguments.py +++ b/sockeye/arguments.py @@ -876,6 +876,12 @@ def add_training_args(params): help='Maximum number of checkpoints the model is allowed to not improve in ' ' on validation data before training is stopped. ' 'Default: %(default)s.') + train_params.add_argument('--max-checkpoints', + type=int, + default=None, + help='Maximum number of checkpoints to continue training the model ' + 'before training is stopped. ' + 'Default: %(default)s.') train_params.add_argument('--min-num-epochs', type=int, default=None, diff --git a/sockeye/constants.py b/sockeye/constants.py index 5281ade2a..34811ff3f 100644 --- a/sockeye/constants.py +++ b/sockeye/constants.py @@ -258,7 +258,7 @@ "keep_last_params", "seed", "max_updates", "min_updates", "max_num_epochs", "min_num_epochs", - "max_samples", "min_samples"] + "max_samples", "min_samples", "max_checkpoints"] # Other argument constants TRAINING_ARG_SOURCE = "--source" diff --git a/sockeye/image_captioning/train.py b/sockeye/image_captioning/train.py index ee8445bfa..97f7207c4 100644 --- a/sockeye/image_captioning/train.py +++ b/sockeye/image_captioning/train.py @@ -351,6 +351,7 @@ def train(args: argparse.Namespace): if min_epochs is not None and max_epochs is not None: check_condition(min_epochs <= max_epochs, "Minimum number of epochs must be smaller than maximum number of epochs") + # Fixed training schedule always runs for a set number of updates if args.learning_rate_schedule: min_updates = None @@ -380,6 +381,7 @@ def train(args: argparse.Namespace): metrics=args.metrics, checkpoint_interval=args.checkpoint_interval, max_num_not_improved=max_num_checkpoint_not_improved, + max_checkpoints=args.max_checkpoints, min_samples=min_samples, max_samples=max_samples, min_updates=min_updates, diff --git a/sockeye/train.py b/sockeye/train.py index b5e72304c..548904d97 100644 --- a/sockeye/train.py +++ b/sockeye/train.py @@ -897,6 +897,7 @@ def train(args: argparse.Namespace) -> training.TrainState: if min_epochs is not None and max_epochs is not None: check_condition(min_epochs <= max_epochs, "Minimum number of epochs must be smaller than maximum number of epochs") + # Fixed training schedule always runs for a set number of updates if args.learning_rate_schedule: min_updates = None @@ -921,6 +922,7 @@ def train(args: argparse.Namespace) -> training.TrainState: metrics=args.metrics, checkpoint_interval=args.checkpoint_interval, max_num_not_improved=max_num_checkpoint_not_improved, + max_checkpoints=args.max_checkpoints, min_samples=min_samples, max_samples=max_samples, min_updates=min_updates, diff --git a/sockeye/training.py b/sockeye/training.py index acaf39d84..7e4f06d7b 100644 --- a/sockeye/training.py +++ b/sockeye/training.py @@ -452,6 +452,7 @@ def fit(self, metrics: List[str], checkpoint_interval: int, max_num_not_improved: int, + max_checkpoints: Optional[int] = None, min_samples: Optional[int] = None, max_samples: Optional[int] = None, min_updates: Optional[int] = None, @@ -478,6 +479,9 @@ def fit(self, :param max_num_not_improved: Stop training if early_stopping_metric did not improve for this many checkpoints. Use -1 to disable stopping based on early_stopping_metric. + :param max_checkpoints: Stop training after this many checkpoints. + Use None to disable. + :param min_samples: Optional minimum number of samples. :param max_samples: Optional maximum number of samples. :param min_updates: Optional minimum number of update steps. @@ -537,6 +541,11 @@ def fit(self, speedometer = Speedometer(frequency=C.MEASURE_SPEED_EVERY, auto_reset=False) tic = time.time() + if max_checkpoints is not None: + max_updates = self.state.updates + max_checkpoints * checkpoint_interval + logger.info(("Resetting max_updates to %d + %d * %d = %d in order to implement stopping after (an additional) %d checkpoints." + % (self.state.updates, max_checkpoints, checkpoint_interval, max_updates, max_checkpoints))) + next_data_batch = train_iter.next() while True: diff --git a/test/unit/test_arguments.py b/test/unit/test_arguments.py index 20c30dffb..326a0f4bf 100644 --- a/test/unit/test_arguments.py +++ b/test/unit/test_arguments.py @@ -140,6 +140,7 @@ def test_model_parameters(test_params, expected_params): optimized_metric=C.PERPLEXITY, checkpoint_interval=4000, max_num_checkpoint_not_improved=32, + max_checkpoints=None, embed_dropout=(.0, .0), transformer_dropout_attention=0.1, transformer_dropout_act=0.1,