Enable embedding logging to tensorboard (#350)
fhieber committed Apr 10, 2018
1 parent ba67443 commit cce1acc
Showing 5 changed files with 84 additions and 14 deletions.
18 changes: 18 additions & 0 deletions sockeye/model.py
@@ -220,4 +220,22 @@ def _get_embed_weights(self, prefix: str) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, mx.sym.Symbol]:
self.decoder.get_num_hidden()))
w_out_target = w_embed_target

self._embed_weight_source_name = w_embed_source.name
self._embed_weight_target_name = w_embed_target.name
self._out_weight_target_name = w_out_target.name
return w_embed_source, w_embed_target, w_out_target

def get_source_embed_params(self) -> Optional[mx.nd.NDArray]:
if self.params is None:
return None
return self.params.get(self._embed_weight_source_name)

def get_target_embed_params(self) -> Optional[mx.nd.NDArray]:
if self.params is None:
return None
return self.params.get(self._embed_weight_target_name)

def get_output_embed_params(self) -> Optional[mx.nd.NDArray]:
if self.params is None:
return None
return self.params.get(self._out_weight_target_name)
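As a usage sketch (not part of the commit), the new accessors return the embedding matrices as NDArrays once the module's parameters have been initialized; the name training_model below is a placeholder for a sockeye.training.TrainingModel instance:

    # Sketch only: training_model is assumed to be an initialized TrainingModel.
    source_embed = training_model.get_source_embed_params()  # mx.nd.NDArray, typically (source_vocab_size, num_embed), or None
    target_embed = training_model.get_target_embed_params()
    output_embed = training_model.get_output_embed_params()

    if source_embed is not None:
        print("source embedding matrix shape:", source_embed.shape)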
4 changes: 3 additions & 1 deletion sockeye/train.py
@@ -793,7 +793,9 @@ def main():

trainer = training.EarlyStoppingTrainer(model=training_model,
optimizer_config=create_optimizer_config(args, source_vocab_sizes),
max_params_files_to_keep=args.keep_last_params)
max_params_files_to_keep=args.keep_last_params,
source_vocabs=source_vocabs,
target_vocab=target_vocab)

trainer.fit(train_iter=train_iter,
validation_iter=eval_iter,
50 changes: 41 additions & 9 deletions sockeye/training.py
@@ -35,6 +35,7 @@
from . import lr_scheduler
from . import model
from . import utils
from . import vocab
from .optimizers import BatchState, CheckpointState, SockeyeOptimizer, OptimizerConfig

logger = logging.getLogger(__name__)
@@ -409,16 +410,22 @@ class EarlyStoppingTrainer:
:param model: TrainingModel instance.
:param optimizer_config: The optimizer configuration.
:param max_params_files_to_keep: Maximum number of params files to keep in the output folder (last n are kept).
:param source_vocabs: Source vocabulary (and optional source factor vocabularies).
:param target_vocab: Target vocabulary.
"""

def __init__(self,
model: TrainingModel,
optimizer_config: OptimizerConfig,
max_params_files_to_keep: int) -> None:
max_params_files_to_keep: int,
source_vocabs: List[vocab.Vocab],
target_vocab: vocab.Vocab) -> None:
self.model = model
self.optimizer_config = optimizer_config
self.max_params_files_to_keep = max_params_files_to_keep
self.tflogger = TensorboardLogger(logdir=os.path.join(model.output_dir, C.TENSORBOARD_NAME))
self.tflogger = TensorboardLogger(logdir=os.path.join(model.output_dir, C.TENSORBOARD_NAME),
source_vocab=source_vocabs[0],
target_vocab=target_vocab)
self.state = None # type: Optional[TrainState]

def fit(self,
@@ -710,9 +717,11 @@ def _update_metrics(self,

tf_metrics = checkpoint_metrics.copy()
tf_metrics.update({"%s_grad" % n: v for n, v in self.state.gradients.items()})
arg_params, aux_params = self.model.module.get_params()
tf_metrics.update(arg_params)
self.tflogger.log_metrics(tf_metrics, self.state.checkpoint)
tf_metrics.update(self.model.params)
self.tflogger.log_metrics(metrics=tf_metrics, checkpoint=self.state.checkpoint)
self.tflogger.log_source_embedding(self.model.get_source_embed_params(), self.state.checkpoint)
self.tflogger.log_target_embedding(self.model.get_target_embed_params(), self.state.checkpoint)
self.tflogger.log_output_embedding(self.model.get_output_embed_params(), self.state.checkpoint)

def _cleanup(self, lr_decay_opt_states_reset: str, process_manager: Optional['DecoderProcessManager'] = None):
"""
@@ -968,17 +977,27 @@ def _load_training_state(self, train_iter: data_io.BaseParallelSampleIter):
class TensorboardLogger:
"""
Thin wrapper for MXBoard API to log training events.
Flushes logging events to disk every 60 seconds.
:param logdir: Directory to write Tensorboard event files to.
:param source_vocab: Optional source vocabulary to log source embeddings.
:param target_vocab: Optional target vocabulary to log target and output embeddings.
"""

def __init__(self, logdir: str) -> None:
def __init__(self,
logdir: str,
source_vocab: Optional[vocab.Vocab] = None,
target_vocab: Optional[vocab.Vocab] = None) -> None:
self.logdir = logdir
self.source_labels = vocab.get_ordered_tokens_from_vocab(source_vocab) if source_vocab is not None else None
self.target_labels = vocab.get_ordered_tokens_from_vocab(target_vocab) if target_vocab is not None else None
try:
import mxboard
logger.info("Logging training events for Tensorboard at '%s'", self.logdir)
if os.path.exists(self.logdir):
logger.info("Deleting existing Tensorboard log directory '%s'", self.logdir)
shutil.rmtree(self.logdir)
self.sw = mxboard.SummaryWriter(logdir=self.logdir)
self.sw = mxboard.SummaryWriter(logdir=self.logdir, flush_secs=60, verbose=False)
except ImportError:
logger.info("mxboard not found. Consider 'pip install mxboard' to log events to Tensorboard.")
self.sw = None
@@ -992,13 +1011,26 @@ def log_metrics(self, metrics: Dict[str, Union[float, int, mx.nd.NDArray]], checkpoint: int):
self.sw.add_histogram(tag=name, values=value, bins=100, global_step=checkpoint)
else:
self.sw.add_scalar(tag=name, value=value, global_step=checkpoint)
self.sw.flush()

def log_graph(self, symbol: mx.sym.Symbol):
if self.sw is None:
return
self.sw.add_graph(symbol)
self.sw.flush()

def log_source_embedding(self, embedding: mx.nd.NDArray, checkpoint: int):
if self.sw is None or self.source_labels is None:
return
self.sw.add_embedding(tag="source", embedding=embedding, labels=self.source_labels, global_step=checkpoint)

def log_target_embedding(self, embedding: mx.nd.NDArray, checkpoint: int):
if self.sw is None or self.target_labels is None:
return
self.sw.add_embedding(tag="target", embedding=embedding, labels=self.target_labels, global_step=checkpoint)

def log_output_embedding(self, embedding: mx.nd.NDArray, checkpoint: int):
if self.sw is None or self.target_labels is None:
return
self.sw.add_embedding(tag="output", embedding=embedding, labels=self.target_labels, global_step=checkpoint)


class Speedometer:
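For context, a minimal standalone sketch of the mxboard call pattern that TensorboardLogger wraps; the vocabulary and embedding matrix here are toy placeholders, not values from the commit:

    import mxnet as mx
    from mxboard import SummaryWriter

    # Toy placeholders: a 4-token vocabulary and a random 4 x 8 embedding matrix.
    labels = ["<pad>", "<unk>", "hello", "world"]
    embedding = mx.nd.random.normal(shape=(len(labels), 8))

    sw = SummaryWriter(logdir="./tb_logs", flush_secs=60, verbose=False)
    # Each row of 'embedding' is tagged with the token at the same vocabulary id,
    # so TensorBoard's projector can show nearest neighbours by token.
    sw.add_embedding(tag="source", embedding=embedding, labels=labels, global_step=0)
    sw.close()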
14 changes: 12 additions & 2 deletions sockeye/vocab.py
@@ -18,7 +18,7 @@
from collections import Counter
from contextlib import ExitStack
from itertools import chain, islice
from typing import Dict, Iterable, List, Mapping, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple

from . import utils
from . import constants as C
@@ -229,7 +229,7 @@ def load_or_create_vocabs(source_paths: List[str],
return [vocab_source] + vocab_source_factors, vocab_target


def reverse_vocab(vocab: Mapping) -> InverseVocab:
def reverse_vocab(vocab: Vocab) -> InverseVocab:
"""
Returns value-to-key mapping from key-to-value-mapping.
@@ -239,6 +239,16 @@ def reverse_vocab(vocab: Mapping) -> InverseVocab:
return {v: k for k, v in vocab.items()}


def get_ordered_tokens_from_vocab(vocab: Vocab) -> List[str]:
"""
Returns the list of tokens in a vocabulary, ordered by increasing vocabulary id.
:param vocab: Input vocabulary.
:return: List of tokens.
"""
return [token for token, token_id in sorted(vocab.items(), key=lambda i: i[1])]


def are_identical(*vocabs: Vocab):
assert len(vocabs) > 0, "At least one vocabulary needed."
return all(set(vocab.items()) == set(vocabs[0].items()) for vocab in vocabs)
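A toy illustration (assumed vocabulary, not from the commit) of why the helper sorts by id: position i in the returned list corresponds to row i of the embedding weight matrix, which is exactly the label order mxboard's add_embedding expects:

    from sockeye.vocab import get_ordered_tokens_from_vocab

    toy_vocab = {"<pad>": 0, "<unk>": 1, "hello": 2, "world": 3}  # assumed toy vocabulary
    labels = get_ordered_tokens_from_vocab(toy_vocab)
    assert labels == ["<pad>", "<unk>", "hello", "world"]
    # labels[i] names the token embedded in row i of a matrix of shape (len(toy_vocab), num_embed).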
12 changes: 10 additions & 2 deletions test/unit/test_vocab.py
@@ -5,7 +5,7 @@
# is located at
#
# http://aws.amazon.com/apache2.0/
#
#
# or in the "license" file accompanying this file. This file is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
@@ -14,7 +14,7 @@
import pytest

import sockeye.constants as C
from sockeye.vocab import build_vocab
from sockeye.vocab import build_vocab, get_ordered_tokens_from_vocab

test_vocab = [
# Example 1
@@ -41,6 +41,7 @@ def test_build_vocab(data, size, min_count, expected):
vocab = build_vocab(data, size, min_count)
assert vocab == expected


test_constants = [
# Example 1
(["one two three", "one two three"], 3, 1, C.VOCAB_SYMBOLS),
@@ -59,3 +60,10 @@ def test_constants_in_vocab(data, size, min_count, constants):
vocab = build_vocab(data, size, min_count)
for const in constants:
assert const in vocab


@pytest.mark.parametrize("vocab, expected_output", [({"<pad>": 0, "a": 4, "b": 2}, ["<pad>", "b", "a"]),
({}, [])])
def test_get_ordered_tokens_from_vocab(vocab, expected_output):
assert get_ordered_tokens_from_vocab(vocab) == expected_output
