Skip to content

Commit

Permalink
Stop after n checkpoints (#661)
Browse files Browse the repository at this point in the history
This change adds the ability to stop training after a given number of checkpoints.
Useful for situations where one is required to share scarce GPU resources.
Added a new command line argument '--max-checkpoints'
  • Loading branch information
tuglat authored and fhieber committed Mar 14, 2019
1 parent 06394b6 commit 2e39633
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 2 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [1.18.85]
### Added
- Added parameter to force training to stop after a given number of checkpoints. Useful when forced to share limited GPU resources.

## [1.18.84]
### Fixed
- Fixed lexical constraints bugs that broke batching and caused large drop in BLEU.
Expand Down
2 changes: 1 addition & 1 deletion sockeye/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '1.18.84'
__version__ = '1.18.85'
6 changes: 6 additions & 0 deletions sockeye/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,12 @@ def add_training_args(params):
help='Maximum number of checkpoints the model is allowed to not improve in '
'<optimized-metric> on validation data before training is stopped. '
'Default: %(default)s.')
train_params.add_argument('--max-checkpoints',
type=int,
default=None,
help='Maximum number of checkpoints to continue training the model '
'before training is stopped. '
'Default: %(default)s.')
train_params.add_argument('--min-num-epochs',
type=int,
default=None,
Expand Down
2 changes: 1 addition & 1 deletion sockeye/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@
"keep_last_params", "seed",
"max_updates", "min_updates",
"max_num_epochs", "min_num_epochs",
"max_samples", "min_samples"]
"max_samples", "min_samples", "max_checkpoints"]

# Other argument constants
TRAINING_ARG_SOURCE = "--source"
Expand Down
2 changes: 2 additions & 0 deletions sockeye/image_captioning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def train(args: argparse.Namespace):
if min_epochs is not None and max_epochs is not None:
check_condition(min_epochs <= max_epochs,
"Minimum number of epochs must be smaller than maximum number of epochs")

# Fixed training schedule always runs for a set number of updates
if args.learning_rate_schedule:
min_updates = None
Expand Down Expand Up @@ -380,6 +381,7 @@ def train(args: argparse.Namespace):
metrics=args.metrics,
checkpoint_interval=args.checkpoint_interval,
max_num_not_improved=max_num_checkpoint_not_improved,
max_checkpoints=args.max_checkpoints,
min_samples=min_samples,
max_samples=max_samples,
min_updates=min_updates,
Expand Down
2 changes: 2 additions & 0 deletions sockeye/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,7 @@ def train(args: argparse.Namespace) -> training.TrainState:
if min_epochs is not None and max_epochs is not None:
check_condition(min_epochs <= max_epochs,
"Minimum number of epochs must be smaller than maximum number of epochs")

# Fixed training schedule always runs for a set number of updates
if args.learning_rate_schedule:
min_updates = None
Expand All @@ -921,6 +922,7 @@ def train(args: argparse.Namespace) -> training.TrainState:
metrics=args.metrics,
checkpoint_interval=args.checkpoint_interval,
max_num_not_improved=max_num_checkpoint_not_improved,
max_checkpoints=args.max_checkpoints,
min_samples=min_samples,
max_samples=max_samples,
min_updates=min_updates,
Expand Down
9 changes: 9 additions & 0 deletions sockeye/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ def fit(self,
metrics: List[str],
checkpoint_interval: int,
max_num_not_improved: int,
max_checkpoints: Optional[int] = None,
min_samples: Optional[int] = None,
max_samples: Optional[int] = None,
min_updates: Optional[int] = None,
Expand All @@ -478,6 +479,9 @@ def fit(self,
:param max_num_not_improved: Stop training if early_stopping_metric did not improve for this many checkpoints.
Use -1 to disable stopping based on early_stopping_metric.
:param max_checkpoints: Stop training after this many checkpoints.
Use None to disable.
:param min_samples: Optional minimum number of samples.
:param max_samples: Optional maximum number of samples.
:param min_updates: Optional minimum number of update steps.
Expand Down Expand Up @@ -537,6 +541,11 @@ def fit(self,
speedometer = Speedometer(frequency=C.MEASURE_SPEED_EVERY, auto_reset=False)
tic = time.time()

if max_checkpoints is not None:
max_updates = self.state.updates + max_checkpoints * checkpoint_interval
logger.info(("Resetting max_updates to %d + %d * %d = %d in order to implement stopping after (an additional) %d checkpoints."
% (self.state.updates, max_checkpoints, checkpoint_interval, max_updates, max_checkpoints)))

next_data_batch = train_iter.next()
while True:

Expand Down
1 change: 1 addition & 0 deletions test/unit/test_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def test_model_parameters(test_params, expected_params):
optimized_metric=C.PERPLEXITY,
checkpoint_interval=4000,
max_num_checkpoint_not_improved=32,
max_checkpoints=None,
embed_dropout=(.0, .0),
transformer_dropout_attention=0.1,
transformer_dropout_act=0.1,
Expand Down

0 comments on commit 2e39633

Please sign in to comment.