Skip to content

Commit

Permalink
Added option --dry-run (#322)
Browse files Browse the repository at this point in the history
* Added option --dry-run
  • Loading branch information
David Vilar authored and fhieber committed Mar 9, 2018
1 parent 618e813 commit ed7e3ef
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Each version section may have have subsections for: _Added_, _Changed_, _Removed
### Added
- Added a flag `--fixed-param-names` to prevent certain parameters from being optimized during training.
This is useful if you want to keep pre-trained embeddings fixed during training.
- Added a flag `--dry-run` to `sockeye.train` to not perform any actual training, but print statistics about the model
and mode of operation.

## [1.17.3]
### Changed
Expand Down
5 changes: 5 additions & 0 deletions sockeye/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,11 @@ def add_training_args(params):
default=-1,
help='Keep only the last n params files, use -1 to keep all files. Default: %(default)s')

train_params.add_argument('--dry-run',
action='store_true',
help="Do not perform any actual training, but print statistics about the model"
" and mode of operation.")


def add_train_cli_args(params):
add_training_io_args(params)
Expand Down
8 changes: 8 additions & 0 deletions sockeye/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
import shutil
import sys
import tempfile
from contextlib import ExitStack
from typing import Any, cast, Optional, Dict, List, Tuple

Expand Down Expand Up @@ -724,6 +725,13 @@ def main():
arguments.add_train_cli_args(params)
args = params.parse_args()

if args.dry_run:
# Modify arguments so that we write to a temporary directory and
# perform 0 training iterations
temp_dir = tempfile.TemporaryDirectory() # Will be automatically removed
args.output = temp_dir.name
args.max_updates = 0

utils.seedRNGs(args.seed)

check_arg_compatibility(args)
Expand Down
3 changes: 2 additions & 1 deletion test/unit/test_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ def test_model_parameters(test_params, expected_params):
decode_and_evaluate_use_cpu=False,
decode_and_evaluate_device_id=None,
seed=13,
keep_last_params=-1)),
keep_last_params=-1,
dry_run=False)),
])
def test_training_arg(test_params, expected_params):
_test_args(test_params, expected_params, arguments.add_training_args)
Expand Down

0 comments on commit ed7e3ef

Please sign in to comment.