Commit

Added constant mapping for metric min/max. To be used by average.py as well, making the --max arg redundant (#104)
fhieber authored Aug 11, 2017
1 parent 4863800 commit 533fae1
Showing 5 changed files with 50 additions and 52 deletions.
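
The diffs below all follow from one idea: each metric's optimization direction and worst possible starting value live in a single constant mapping in sockeye/constants.py, so callers look the direction up by metric name instead of branching per metric or taking a --max flag. A minimal sketch of the pattern, using the constant names introduced in this commit (the is_better helper and the sample scores are illustrative only, not part of the change):

import numpy as np

PERPLEXITY, ACCURACY, BLEU = 'perplexity', 'accuracy', 'bleu'
METRICS = [PERPLEXITY, ACCURACY, BLEU]
METRIC_MAXIMIZE = {ACCURACY: True, BLEU: True, PERPLEXITY: False}  # is higher better?
METRIC_WORST = {ACCURACY: 0.0, BLEU: 0.0, PERPLEXITY: np.inf}      # worst possible score per metric

def is_better(metric: str, value: float, best: float) -> bool:
    # Compare in the direction implied by the metric; no per-metric if/else needed.
    return value > best if METRIC_MAXIMIZE[metric] else value < best

best = METRIC_WORST[PERPLEXITY]
for score in (120.0, 85.5, 90.2):
    if is_better(PERPLEXITY, score, best):
        best = score
assert best == 85.5
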
16 changes: 7 additions & 9 deletions sockeye/arguments.py
@@ -46,25 +46,23 @@ def add_average_args(params):
nargs="+",
help="either a single model directory (automatic checkpoint selection) "
"or multiple .params files (manual checkpoint selection)")
average_params.add_argument(
"--max", action="store_true", help="Maximize metric.")
average_params.add_argument(
"--metric",
help="Name of the metric to choose n-best checkpoints from. (default: {})".format(C.PERPLEXITY),
help="Name of the metric to choose n-best checkpoints from. Default: %(default)s.",
default=C.PERPLEXITY,
choices=[C.PERPLEXITY, C.BLEU])
choices=C.METRICS)
average_params.add_argument(
"-n",
type=int,
default=4,
help="number of checkpoints to find (default: 4)")
help="number of checkpoints to find. Default: %(default)s.")
average_params.add_argument(
"--output", "-o", required=True, type=str, help="output param file")
"--output", "-o", required=True, type=str, help="File to write averaged parameters to.")
average_params.add_argument(
"--strategy",
choices=["best", "last", "lifespan"],
default="best",
help="selection method (default: best)")
help="selection method. Default: %(default)s.")


def add_io_args(params):
@@ -338,8 +336,8 @@ def add_training_args(params):
choices=[C.PERPLEXITY, C.ACCURACY],
help='Names of metrics to track on training and validation data. Default: %(default)s.')
train_params.add_argument('--optimized-metric',
default='perplexity',
choices=[C.PERPLEXITY, C.ACCURACY, C.BLEU],
default=C.PERPLEXITY,
choices=C.METRICS,
help='Metric to optimize with early stopping {%(choices)s}. '
'Default: %(default)s.')

40 changes: 20 additions & 20 deletions sockeye/average.py
@@ -21,16 +21,14 @@
import argparse
import itertools
import os
from typing import Dict, Iterable, Tuple, List
from typing import Dict, Iterable

import mxnet as mx

import sockeye.constants as C
import sockeye.utils
import sockeye.arguments
from . import arguments
from . import constants as C
from . import utils
from sockeye.log import setup_main_logger, log_sockeye_version
from sockeye.utils import check_condition


logger = setup_main_logger(__name__, console=True, file_logging=False)

@@ -46,33 +44,32 @@ def average(param_paths: Iterable[str]) -> Dict[str, mx.nd.NDArray]:
all_aux_params = []
for path in param_paths:
logger.info("Loading parameters from '%s'", path)
arg_params, aux_params = sockeye.utils.load_params(path)
arg_params, aux_params = utils.load_params(path)
all_arg_params.append(arg_params)
all_aux_params.append(aux_params)

logger.info("%d models loaded", len(all_arg_params))
check_condition(all(all_arg_params[0].keys() == p.keys() for p in all_arg_params),
"arg_param names do not match across models")
check_condition(all(all_aux_params[0].keys() == p.keys() for p in all_aux_params),
"aux_param names do not match across models")
utils.check_condition(all(all_arg_params[0].keys() == p.keys() for p in all_arg_params),
"arg_param names do not match across models")
utils.check_condition(all(all_aux_params[0].keys() == p.keys() for p in all_aux_params),
"aux_param names do not match across models")

avg_params = {}
# average arg_params
for k in all_arg_params[0]:
arrays = [p[k] for p in all_arg_params]
avg_params["arg:" + k] = sockeye.utils.average_arrays(arrays)
avg_params["arg:" + k] = utils.average_arrays(arrays)
# average aux_params
for k in all_aux_params[0]:
arrays = [p[k] for p in all_aux_params]
avg_params["aux:" + k] = sockeye.utils.average_arrays(arrays)
avg_params["aux:" + k] = utils.average_arrays(arrays)

return avg_params


def find_checkpoints(model_path: str, size=4, strategy="best", maximize=False, metric: str = C.PERPLEXITY) \
-> Iterable[str]:
def find_checkpoints(model_path: str, size=4, strategy="best", metric: str = C.PERPLEXITY) -> Iterable[str]:
"""
Finds N best points from .metrics file according to strategy
Finds N best points from .metrics file according to strategy.
:param metric: Metric according to which checkpoints are selected. Corresponds to columns in model/metrics file.
:param model_path: Path to model.
@@ -81,8 +78,9 @@ def find_checkpoints(model_path: str, size=4, strategy="best", maximize=False, m
:param maximize: Whether the value of the metric should be maximized.
:return: List of paths corresponding to chosen checkpoints.
"""
maximize = C.METRIC_MAXIMIZE[metric]
metrics_path = os.path.join(model_path, C.METRICS_NAME)
points = sockeye.utils.read_metrics_points(metrics_path, model_path, metric=metric)
points = utils.read_metrics_points(metrics_path, model_path, metric=metric)

if strategy == "best":
# N best scoring points
@@ -151,14 +149,16 @@ def main():
"""
log_sockeye_version(logger)
params = argparse.ArgumentParser(description="Averages parameters from multiple models.")
sockeye.arguments.add_average_args(params)
arguments.add_average_args(params)
args = params.parse_args()

if len(args.inputs) > 1:
avg_params = average(args.inputs)
else:
param_paths = find_checkpoints(args.inputs[0], args.n, args.strategy,
args.max, args.metric)
param_paths = find_checkpoints(model_path=args.inputs[0],
size=args.n,
strategy=args.strategy,
metric=args.metric)
avg_params = average(param_paths)

mx.nd.save(args.output, avg_params)
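
With --max gone, callers of find_checkpoints only pass the metric name; whether that metric is maximized is looked up from C.METRIC_MAXIMIZE inside the function. A usage sketch against the signatures shown in this diff (the model directory and output path are placeholders):

import mxnet as mx
from sockeye import average

# Pick the 4 best checkpoints by validation perplexity, then average their parameters.
param_paths = average.find_checkpoints(model_path="my_model_dir",
                                       size=4,
                                       strategy="best",
                                       metric="perplexity")
avg_params = average.average(param_paths)
mx.nd.save("averaged.params", avg_params)
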
27 changes: 11 additions & 16 deletions sockeye/callback.py
@@ -29,6 +29,7 @@
import sockeye.checkpoint_decoder
import sockeye.constants as C
import sockeye.inference
import sockeye.utils

logger = logging.getLogger(__name__)

@@ -80,21 +81,12 @@ def __init__(self,
self.speedometer = mx.callback.Speedometer(batch_size=batch_size,
frequent=C.MEASURE_SPEED_EVERY,
auto_reset=False)
sockeye.utils.check_condition(optimized_metric in C.METRICS, "Unsupported metric: %s" % optimized_metric)
if optimized_metric == C.BLEU:
sockeye.utils.check_condition(self.checkpoint_decoder is not None, "%s requires CheckpointDecoder" % C.BLEU)
self.optimized_metric = optimized_metric
if self.optimized_metric == C.PERPLEXITY:
self.minimize = True
self.validation_best = np.inf
elif self.optimized_metric == C.ACCURACY:
self.minimize = False
self.validation_best = -np.inf
elif self.optimized_metric == C.BLEU:
assert self.checkpoint_decoder is not None, "BLEU requires CheckpointDecoder"
self.minimize = False
self.validation_best = -np.inf
else:
raise ValueError("No other metrics supported")
logger.info("Early stopping by optimizing '%s' (minimize=%s)",
self.optimized_metric, self.minimize)
self.validation_best = C.METRIC_WORST[self.optimized_metric]
logger.info("Early stopping by optimizing '%s'", self.optimized_metric)
self.tic = 0

def get_best_checkpoint(self) -> int:
Expand All @@ -109,8 +101,11 @@ def get_best_validation_score(self) -> float:
"""
return self.validation_best

def _is_better(self, value):
return value < self.validation_best if self.minimize else value > self.validation_best
def _is_better(self, value: float) -> bool:
if C.METRIC_MAXIMIZE[self.optimized_metric]:
return value > self.validation_best
else:
return value < self.validation_best

def batch_end_callback(self, epoch: int, nbatch: int, metric: mx.metric.EvalMetric):
"""
6 changes: 5 additions & 1 deletion sockeye/constants.py
@@ -15,6 +15,7 @@
Defines various constants used througout the project
"""
import mxnet as mx
import numpy as np

BOS_SYMBOL = "<s>"
EOS_SYMBOL = "</s>"
@@ -162,10 +163,13 @@
BATCH_MAJOR = "NTC"
TIME_MAJOR = "TNC"

# metric names
# metrics
ACCURACY = 'accuracy'
PERPLEXITY = 'perplexity'
BLEU = 'bleu'
METRICS = [PERPLEXITY, ACCURACY, BLEU]
METRIC_MAXIMIZE = {ACCURACY: True, BLEU: True, PERPLEXITY: False}
METRIC_WORST = {ACCURACY: 0.0, BLEU: 0.0, PERPLEXITY: np.inf}

# loss names
CROSS_ENTROPY = 'cross-entropy'
13 changes: 7 additions & 6 deletions test/unit/test_callback.py
@@ -17,13 +17,14 @@
import pytest
import numpy as np
import sockeye.callback
import sockeye.utils
import tempfile

test_constants = [('perplexity', np.inf, True,
test_constants = [('perplexity', np.inf,
[{'perplexity': 100.0, '_': 42}, {'perplexity': 50.0}, {'perplexity': 60.0}, {'perplexity': 80.0}],
[{'perplexity': 200.0}, {'perplexity': 100.0}, {'perplexity': 100.001}, {'perplexity': 99.99}],
[True, True, False, True]),
('accuracy', -np.inf, False,
('accuracy', 0.0,
[{'accuracy': 100.0}, {'accuracy': 50.0}, {'accuracy': 60.0}, {'accuracy': 80.0}],
[{'accuracy': 200.0}, {'accuracy': 100.0}, {'accuracy': 100.001}, {'accuracy': 99.99}],
[True, False, False, False])]
@@ -38,17 +39,16 @@ def get_name_value(self):
yield metric_name, value


@pytest.mark.parametrize("optimized_metric, initial_best, minimize, train_metrics, eval_metrics, improved_seq",
@pytest.mark.parametrize("optimized_metric, initial_best, train_metrics, eval_metrics, improved_seq",
test_constants)
def test_callback(optimized_metric, initial_best, minimize, train_metrics, eval_metrics, improved_seq):
def test_callback(optimized_metric, initial_best, train_metrics, eval_metrics, improved_seq):
with tempfile.TemporaryDirectory() as tmpdir:
batch_size = 32
monitor = sockeye.callback.TrainingMonitor(batch_size=batch_size,
output_folder=tmpdir,
optimized_metric=optimized_metric)
assert monitor.optimized_metric == optimized_metric
assert monitor.get_best_validation_score() == initial_best
assert monitor.minimize == minimize

for checkpoint, (train_metric, eval_metric, expected_improved) in enumerate(
zip(train_metrics, eval_metrics, improved_seq), 1):
@@ -61,8 +61,9 @@ def test_callback(optimized_metric, initial_best, minimize, train_metrics, eval_


def test_bleu_requires_checkpoint_decoder():
with pytest.raises(AssertionError), tempfile.TemporaryDirectory() as tmpdir:
with pytest.raises(sockeye.utils.SockeyeError) as e, tempfile.TemporaryDirectory() as tmpdir:
sockeye.callback.TrainingMonitor(batch_size=1,
output_folder=tmpdir,
optimized_metric='bleu',
checkpoint_decoder=None)
assert "bleu requires CheckpointDecoder" == str(e.value)
