diff --git a/.dockerignore b/.dockerignore index bf6d83d..2de1d56 100644 --- a/.dockerignore +++ b/.dockerignore @@ -5,10 +5,11 @@ Dockerfile old-results/ .git/ .gitignore +.ipynb_checkpoints/ .mypy_cache .pytest_cache/ -results -scripts +results/ +scripts/ .travis.yml .venv .vscode diff --git a/.gitignore b/.gitignore index a691629..2e04592 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ *.pyc __pycache__ .venv/ +.ipynb_checkpoints # Testing @@ -14,5 +15,9 @@ __pycache__ .vscode/ -# AlleNLP +# Experiment-related results/ +data/ +*.yaml +*.npy +*.jinja2 diff --git a/Dockerfile b/Dockerfile index 6205775..dbc7ddc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,33 +1,28 @@ -FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04 -RUN echo "deb-src http://archive.ubuntu.com/ubuntu/ xenial main" | tee -a /etc/apt/sources.list -RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential \ - cmake \ - git \ - curl \ - vim \ - ca-certificates \ - libjpeg-dev \ - libpng-dev &&\ - rm -rf /var/lib/apy/lists/* +# FROM python:3.6.8-jessie +FROM pytorch/pytorch:1.1.0-cuda10.0-cudnn7.5-runtime +ENV LC_ALL=C.UTF-8 +ENV LANG=C.UTF-8 -RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ - chmod +x ~/miniconda.sh && \ - ~/miniconda.sh -b -p /opt/conda && \ - rm ~/miniconda.sh && \ - /opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include cython typing && \ - /opt/conda/bin/conda install -y -c pytorch magma-cuda100 && \ - /opt/conda/bin/conda clean -ya -ENV PATH /opt/conda/bin:$PATH +ENV PATH /usr/local/nvidia/bin/:$PATH +ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 + +# Tell nvidia-docker the driver spec that we need as well as to +# use all available devices, which are mounted at /usr/local/nvidia. +# The LABEL supports an older version of nvidia-docker, the env +# variables a newer one. +ENV NVIDIA_VISIBLE_DEVICES all +ENV NVIDIA_DRIVER_CAPABILITIES compute,utility +LABEL com.nvidia.volumes.needed="nvidia_driver" WORKDIR /workspace +RUN chmod -R a+w /workspace + +COPY requirements.txt . +RUN pip install --upgrade pip +RUN pip install -r requirements.txt -COPY experiments/ experiments/ -COPY kglm/ kglm/ COPY .pylintrc .pylintrc COPY pytest.ini pytest.ini COPY README.md README.md -COPY requirements.txt . 
- -RUN pip install -r requirements.txt -RUN chmod -R a+w /workspace +COPY kglm/ kglm/ +COPY experiments/ experiments/ diff --git a/experiments/conll_2012_vocab.jsonnet b/experiments/conll_2012_vocab.jsonnet new file mode 100644 index 0000000..e1c9738 --- /dev/null +++ b/experiments/conll_2012_vocab.jsonnet @@ -0,0 +1,17 @@ +{ + "vocabulary": { + "type": "extended", + "max_vocab_size": {"tokens": 10000} + }, + "datasets_for_vocab_creation": ["train"], + "dataset_reader": { + "type": "conll2012_jsonl", + "token_indexers": { + "tokens": { + "type": "single_id", + "lowercase_tokens": false + } + } + }, + "train_data_path": "data/conll-2012/processed/train.jsonl", +} diff --git a/experiments/entity_disc.jsonnet b/experiments/entity_disc.jsonnet index 75c6f73..4b9e9e7 100644 --- a/experiments/entity_disc.jsonnet +++ b/experiments/entity_disc.jsonnet @@ -37,7 +37,7 @@ "trainer": { "type": "lm", "cuda_device": 0, - "num_epochs": 750, + "num_epochs": 13, "optimizer": { "type": "adam", "lr": 0.0003 @@ -48,4 +48,4 @@ "directory_path": "data/enhanced-wikitext-2/vocab", "extend": false } -} \ No newline at end of file +} diff --git a/experiments/entity_disc_conll2012.jsonnet b/experiments/entity_disc_conll2012.jsonnet index 2cee49d..b5fef64 100644 --- a/experiments/entity_disc_conll2012.jsonnet +++ b/experiments/entity_disc_conll2012.jsonnet @@ -6,12 +6,6 @@ }, "dataset_reader": { "type": "conll2012_jsonl", - "token_indexers": { - "tokens": { - "type": "single_id", - "lowercase_tokens": true - } - } }, "train_data_path": "data/conll-2012/processed/train.jsonl", "validation_data_path": "data/conll-2012/processed/dev.jsonl", diff --git a/experiments/entity_disc_conll2012_no_peeking.jsonnet b/experiments/entity_disc_conll2012_no_peeking.jsonnet index 5b9b5bf..ecbf5e1 100644 --- a/experiments/entity_disc_conll2012_no_peeking.jsonnet +++ b/experiments/entity_disc_conll2012_no_peeking.jsonnet @@ -6,13 +6,7 @@ }, "dataset_reader": { "type": "conll2012_jsonl", - "offset": 1, - "token_indexers": { - "tokens": { - "type": "single_id", - "lowercase_tokens": true - } - } + "offset": 1 }, "train_data_path": "data/conll-2012/processed/train.jsonl", "validation_data_path": "data/conll-2012/processed/dev.jsonl", @@ -22,13 +16,13 @@ "token_embedders": { "tokens": { "type": "embedding", - "embedding_dim": 128, + "embedding_dim": 256, "trainable": true }, }, }, - "embedding_dim": 128, - "hidden_size": 128, + "embedding_dim": 256, + "hidden_size": 256, "num_layers": 1, "max_mention_length": 100, "max_embeddings": 100, @@ -37,8 +31,8 @@ }, "iterator": { "type": "fancy", - "batch_size": 16, - "split_size": 15, + "batch_size": 343, + "split_size": 30, "splitting_keys": [ "source", "entity_types", @@ -48,8 +42,8 @@ }, "validation_iterator": { "type": "fancy", - "batch_size": 16, - "split_size": 15, + "batch_size": 343, + "split_size": 128, "splitting_keys": [ "source", "entity_types", @@ -64,7 +58,7 @@ "cuda_device": 0, "optimizer": { "type": "adam", - "lr": 1e-4 + "lr": 1e-3 }, "validation_metric": "+eid_acc" } diff --git a/experiments/entity_disc_conll2012_prp.jsonnet b/experiments/entity_disc_conll2012_prp.jsonnet new file mode 100644 index 0000000..5aaaae5 --- /dev/null +++ b/experiments/entity_disc_conll2012_prp.jsonnet @@ -0,0 +1,64 @@ +{ + "vocabulary": { + "type": "extended", + "extend": false, + "directory_path": "/kermit/rlogan/entity-nlm/data/vocabulary" + }, + "dataset_reader": { + "type": "conll2012_jsonl", + }, + "train_data_path": "/kermit/rlogan/entity-nlm/data/conll-2012/processed/train.jsonl", + 
"validation_data_path": "/kermit/rlogan/entity-nlm/data/conll-2012/processed/dev.jsonl", + "model": { + "type": "entitydisc", + "text_field_embedder": { + "token_embedders": { + "tokens": { + "type": "embedding", + "embedding_dim": 256, + "trainable": true + }, + }, + }, + "embedding_dim": 256, + "hidden_size": 256, + "num_layers": 1, + "max_mention_length": 100, + "max_embeddings": 100, + "dropout_rate": 0.4, + "variational_dropout_rate": 0.1 + }, + "iterator": { + "type": "fancy", + "batch_size": 343, + "split_size": 30, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + }, + "validation_iterator": { + "type": "fancy", + "batch_size": 343, + "split_size": 128, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + "truncate": false + }, + "trainer": { + "type": "lm", + "num_epochs": 400, + "cuda_device": 0, + "optimizer": { + "type": "adam", + "lr": 1e-3 + }, + "validation_metric": "+eid_acc" + } +} diff --git a/experiments/entity_nlm.jsonnet b/experiments/entity_nlm.jsonnet index 9bfc403..559e9ab 100644 --- a/experiments/entity_nlm.jsonnet +++ b/experiments/entity_nlm.jsonnet @@ -53,7 +53,7 @@ "num_epochs": 750, "optimizer": { "type": "adam", - "lr": 0.0003 + "lr": 0.0001 } }, "vocabulary": { diff --git a/experiments/entity_nlm_conll2012.jsonnet b/experiments/entity_nlm_conll2012.jsonnet index 609151b..07b1c56 100644 --- a/experiments/entity_nlm_conll2012.jsonnet +++ b/experiments/entity_nlm_conll2012.jsonnet @@ -1,17 +1,11 @@ { "vocabulary": { "type": "extended", - "extend": false, - "directory_path": "data/vocabulary" + "directory_path": "data/vocabulary", + "extend": false }, "dataset_reader": { "type": "conll2012_jsonl", - "token_indexers": { - "tokens": { - "type": "single_id", - "lowercase_tokens": true - } - } }, "train_data_path": "data/conll-2012/processed/train.jsonl", "validation_data_path": "data/conll-2012/processed/dev.jsonl", @@ -22,24 +16,24 @@ "token_embedders": { "tokens": { "type": "embedding", - "embedding_dim": 256, + "embedding_dim": 300, "trainable": true }, }, }, - "embedding_dim": 256, - "hidden_size": 256, + "embedding_dim": 300, + "hidden_size": 300, "num_layers": 1, "max_mention_length": 100, "max_embeddings": 100, - "tie_weights": true, - "dropout_rate": 0.4, - "variational_dropout_rate": 0.1 + "tie_weights": false, + "dropout_rate": 0.1, + "variational_dropout_rate": 0.2 }, "iterator": { "type": "fancy", - "batch_size": 512, - "split_size": 15, + "batch_size": 256, + "split_size": 120, "splitting_keys": [ "source", "entity_types", @@ -49,8 +43,8 @@ }, "validation_iterator": { "type": "fancy", - "batch_size": 512, - "split_size": 15, + "batch_size": 343, + "split_size": 128, "splitting_keys": [ "source", "entity_types", @@ -61,11 +55,11 @@ }, "trainer": { "type": "lm", - "num_epochs": 40, + "num_epochs": 400, "cuda_device": 0, "optimizer": { "type": "adam", - "lr": 1e-3 + "lr": 1e-3, } } } diff --git a/experiments/entity_nlm_conll2012_prp.jsonnet b/experiments/entity_nlm_conll2012_prp.jsonnet new file mode 100644 index 0000000..f49ab2c --- /dev/null +++ b/experiments/entity_nlm_conll2012_prp.jsonnet @@ -0,0 +1,65 @@ +{ + "vocabulary": { + "type": "extended", + "directory_path": "/kermit/rlogan/entity-nlm/data/vocabulary", + "extend": false + }, + "dataset_reader": { + "type": "conll2012_jsonl", + }, + "train_data_path": "/kermit/rlogan/entity-nlm/data/conll-2012/processed/train.jsonl", + "validation_data_path": "/kermit/rlogan/entity-nlm/data/conll-2012/processed/dev.jsonl", 
+ "datasets_for_vocab_creation": ["train"], + "model": { + "type": "entitynlm", + "text_field_embedder": { + "token_embedders": { + "tokens": { + "type": "embedding", + "embedding_dim": 200, + "trainable": true + }, + }, + }, + "embedding_dim": 200, + "hidden_size": 200, + "num_layers": 1, + "max_mention_length": 100, + "max_embeddings": 100, + "tie_weights": false, + "dropout_rate": 0.2, + "variational_dropout_rate": 0.2 + }, + "iterator": { + "type": "fancy", + "batch_size": 60, + "split_size": 70, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + }, + "validation_iterator": { + "type": "fancy", + "batch_size": 60, + "split_size": 70, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + "truncate": false + }, + "trainer": { + "type": "lm", + "num_epochs": 400, + "cuda_device": 0, + "optimizer": { + "type": "adam", + "lr": 3e-4, + } + } +} diff --git a/experiments/kglm-copy.jsonnet b/experiments/kglm-copy.jsonnet new file mode 100644 index 0000000..490234c --- /dev/null +++ b/experiments/kglm-copy.jsonnet @@ -0,0 +1,116 @@ +{ + "vocabulary": { + "type": "extended", + "extend": false, + "directory_path": "data/linked-wikitext-2/vocab" + }, + "dataset_reader": { + "type": "enhanced-wikitext-kglm", + "alias_database_path": "data/linked-wikitext-2/alias.pkl" + }, + "train_data_path": "data/linked-wikitext-2/train.jsonl", + "validation_data_path": "data/linked-wikitext-2/valid.jsonl", + "model": { + "type": "kglm", + "token_embedder": { + "token_embedders": { + "tokens": { + "type": "embedding", + "embedding_dim": 400, + "trainable": true + } + } + }, + "entity_embedder": { + "token_embedders": { + "entity_ids": { + "type": "embedding", + "pretrained_file": "data/linked-wikitext-2/embeddings.entities.txt", + "embedding_dim": 256, + "trainable": false, + "vocab_namespace": "entity_ids" + } + } + }, + "relation_embedder": { + "token_embedders": { + "relations": { + "type": "embedding", + "pretrained_file": "data/linked-wikitext-2/embeddings.relations.txt", + "embedding_dim": 256, + "trainable": true, + "vocab_namespace": "relations" + } + } + }, + "alias_encoder": { + "type": "lstm", + "input_size": 400, + "hidden_size": 400 + }, + "knowledge_graph_path": "data/linked-wikitext-2/knowledge_graph.pkl", + "use_shortlist": false, + "hidden_size": 1150, + "num_layers": 3, + "cutoff": 30, + "tie_weights": true, + "initializer": [ + ["token_embedder.weight", {"type": "uniform", "a": -0.1, "b": 0.1}], + ["decoder.bias", {"type": "constant", "val": 0.0}] + ] + }, + "iterator": { + "type": "fancy", + "batch_size": 60, + "split_size": 70, + "splitting_keys": [ + "source", + "target", + "mention_type", + "raw_entity_ids", + "entity_ids", + "parent_ids", + "relations", + "shortlist_inds", + "alias_copy_inds" + ] + }, + "validation_iterator": { + "type": "fancy", + "batch_size": 60, + "split_size": 70, + "splitting_keys": [ + "source", + "target", + "mention_type", + "raw_entity_ids", + "entity_ids", + "parent_ids", + "relations", + "shortlist_inds", + "alias_copy_inds" + ], + "truncate": false + }, + "trainer": { + "type": "lm", + "num_epochs": 500, + "cuda_device": 0, + // "grad_clipping": 0.25, + // "optimizer": { + // "type": "nt-asgd", + // "lr": 22.5, + // "weight_decay": 1.2e-6 + // }, + // "learning_rate_scheduler": { + // "type": "nt-asgd", + // "non_monotone_interval": 5 + // }, + "optimizer": { + "type": "adam", + "lr": 3e-4, + "weight_decay": 1.2e-6 + }, + "validation_metric": "-ppl" + } +} diff --git 
a/kglm/commands/__init__.py b/kglm/commands/__init__.py
index 2d5dd9c..d4e0ae9 100644
--- a/kglm/commands/__init__.py
+++ b/kglm/commands/__init__.py
@@ -1,2 +1,3 @@
 from .evaluate_perplexity import EvaluatePerplexity
 from .complete_the_sentence import CompleteTheSentence
+from .beamsum import BeamSum
diff --git a/kglm/commands/beamsum.py b/kglm/commands/beamsum.py
new file mode 100644
index 0000000..9b0bbeb
--- /dev/null
+++ b/kglm/commands/beamsum.py
@@ -0,0 +1,195 @@
+import argparse
+import json
+import logging
+import math
+from typing import Any, Dict, Iterator
+
+from allennlp.commands.subcommand import Subcommand
+from allennlp.common.util import prepare_environment
+from allennlp.common.checks import check_for_gpu
+from allennlp.common.tqdm import Tqdm
+from allennlp.data import Instance
+from allennlp.data.dataset_readers.dataset_reader import DatasetReader
+from allennlp.data.iterators import BasicIterator, DataIterator
+from allennlp.models import Model
+from allennlp.models.archival import load_archive
+from allennlp.nn import util
+import numpy as np
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class BeamSum(Subcommand):
+    def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
+        # pylint: disable=protected-access
+        description = '''Upper bound the specified model's perplexity using beam search'''
+        subparser = parser.add_parser(name, description=description,
+                                      help='Upper bound the perplexity of the specified model using beam search')
+
+        subparser.add_argument('model_archive_file', type=str, help='path to an archived trained model')
+
+        subparser.add_argument('sampler_archive_file', type=str,
+                               help='path to an archived trained model for generating samples')
+
+        subparser.add_argument('input_file', type=str, help='path to the file containing the evaluation data')
+
+        subparser.add_argument('--output-file', type=str, help='path to output file')
+
+        subparser.add_argument('--weights-file',
+                               type=str,
+                               help='a path that overrides which weights file to use')
+
+        cuda_device = subparser.add_mutually_exclusive_group(required=False)
+        cuda_device.add_argument('--cuda-device',
+                                 type=int,
+                                 default=-1,
+                                 help='id of GPU to use (if any)')
+
+        subparser.add_argument('-o', '--overrides',
+                               type=str,
+                               default="",
+                               help='a JSON structure used to override the experiment configuration')
+
+        subparser.add_argument('--batch-size',
+                               type=int,
+                               default=None,
+                               help='Batch size (default: whatever iterator was set to)')
+
+        subparser.add_argument('--split-size',
+                               type=int,
+                               default=None,
+                               help='Split size (default: whatever iterator was set to)')
+
+        subparser.add_argument('--beam-width',
+                               type=int,
+                               default=2,
+                               help='Beam width')
+
+        subparser.set_defaults(func=evaluate_from_args)
+
+        return subparser
+
+
+def evaluate_perplexity(model: Model,
+                        sampler: Model,
+                        instances: Iterator[Instance],
+                        data_iterator: DataIterator,
+                        cuda_device: int,
+                        beam_width: int) -> Dict[str, Any]:
+    check_for_gpu(cuda_device)
+
+    logger.info('Iterating over dataset')
+
+    weight = None
+
+    iterator = data_iterator(instances, num_epochs=1, shuffle=False)
+    generator_tqdm = Tqdm.tqdm(iterator, total=0)
+
+    model.eval()
+    sampler.eval()
+    sampler._state = None
+
+    summand = None
+    denom = None
+    #summand = torch.tensor(0.0)
+    # penalized_summand = torch.tensor(0.0)
+
+    held_over_data = None
+
+    for batch in generator_tqdm:
+
+        # We need sequence length to help compute perplexity
+        batch_size, _ = batch['source']['tokens'].shape
+        n_tokens = util.get_text_field_mask(batch['source']).float().sum(dim=-1)
+        if denom is None:
+            denom = n_tokens.sum()
+        else:
+            denom += n_tokens.sum()
+
+        summand = util.move_to_device(summand, cuda_device)
+        batch = util.move_to_device(batch, cuda_device)
+
+        # Draw a sample
+        with torch.no_grad():
+            # sample = sampler.beam_search(source=batch['source'],
+            #                              reset=batch['reset'],
+            #                              target=batch['target'],
+            #                              metadata=batch['metadata'],
+            #                              k=beam_width)
+            sample = sampler.beam_search(**batch, k=beam_width)
+
+        # Evaluate on sample
+        with torch.no_grad():
+            model_output = model(**sample)
+            # gold_output = model(**batch)
+
+        model_logp = model_output['logp']
+        # logger.debug(model_logp)
+        # logger.debug(gold_output['logp'])
+        model_logp = model_logp.view(batch_size, beam_width)
+        model_logp = torch.logsumexp(model_logp, -1)
+
+        # logger.debug(torch.exp(-model_logp.sum() / n_tokens.sum()))
+
+        if summand is None:
+            summand = model_logp.sum()
+        else:
+            summand += model_logp.sum()
+
+        logger.debug(torch.exp(-summand / denom))
+
+    ppl = torch.exp(-summand / denom)
+
+    metrics = {
+        'ppl': ppl
+    }
+    return metrics
+
+def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
+    # Disable some of the more verbose logging statements
+    logging.getLogger('allennlp.common.params').disabled = True
+    logging.getLogger('allennlp.nn.initializers').disabled = True
+    logging.getLogger('allennlp.modules.token_embedders.embedding').setLevel(logging.INFO)
+
+    # Load model from archive
+    model_archive = load_archive(args.model_archive_file, args.cuda_device, args.overrides, args.weights_file)
+    config = model_archive.config
+    prepare_environment(config)
+    model = model_archive.model
+    model.eval()
+
+    # Load sampler
+    sampler_archive = load_archive(args.sampler_archive_file, args.cuda_device, args.overrides, args.weights_file)
+    sampler = sampler_archive.model
+    sampler.eval()
+
+    # Load the evaluation data. NOTE: We are using the model's reader!
+    validation_dataset_reader_params = config.pop('validation_dataset_reader', None)
+    if validation_dataset_reader_params is not None:
+        dataset_reader = DatasetReader.from_params(validation_dataset_reader_params)
+    else:
+        dataset_reader = DatasetReader.from_params(config.pop('dataset_reader'))
+    evaluation_data_path = args.input_file
+    logger.info('Reading evaluation data from: %s', evaluation_data_path)
+    instances = dataset_reader.read(evaluation_data_path)
+
+    # To avoid hairy issues with splitting, we opt to use a basic iterator so that we can
+    # generate samples for entire sequences.
+    iterator_params = config.pop('iterator', 'None')
+    if args.batch_size is not None:
+        iterator_params['batch_size'] = args.batch_size
+    if args.split_size is not None:
+        iterator_params['split_size'] = args.split_size
+    iterator_params['truncate'] = False
+    iterator = DataIterator.from_params(iterator_params)
+    iterator.index_with(model.vocab)
+    metrics = evaluate_perplexity(model, sampler, instances, iterator,
+                                  args.cuda_device, args.beam_width)
+
+    logger.info('Finished evaluating.')
+    logger.info('Metrics:')
+    for key, metric in metrics.items():
+        logger.info('%s: %s', key, metric)
+
+    return metrics
diff --git a/kglm/commands/evaluate_perplexity.py b/kglm/commands/evaluate_perplexity.py
index 21dfeeb..2300776 100644
--- a/kglm/commands/evaluate_perplexity.py
+++ b/kglm/commands/evaluate_perplexity.py
@@ -34,7 +34,7 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argpar
 
         subparser.add_argument('input_file', type=str, help='path to the file containing the evaluation data')
 
-        subparser.add_argument('--output-file', type=str, help='path to output file')
+        subparser.add_argument('--out', type=str, help='prefix of output files')
 
         subparser.add_argument('--weights-file',
                                type=str,
@@ -53,12 +53,12 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argpar
 
         subparser.add_argument('--batch-size',
                                type=int,
-                               default=None,
+                               default=1,
                                help='Batch size (default: whatever iterator was set to)')
 
         subparser.add_argument('--split-size',
                                type=int,
-                               default=None,
+                               default=int(1e10),
                                help='Split size (default: whatever iterator was set to)')
 
         subparser.add_argument('--num-samples',
@@ -83,24 +83,50 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argpar
 
         return subparser
 
-PRESERVED_FIELDS = {'source', 'reset'}
+# TODO: Make sure this still makes sense...
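The hunks below rework `evaluate_perplexity` so that it accumulates per-sequence importance weights and derives perplexity from them. A minimal NumPy sketch of the estimator being computed follows; it is not part of the patch, and the array names and shapes are illustrative assumptions standing in for the `model_logp`/`sample_logp` tensors produced by the model and the sampler.

import numpy as np

def estimate_ppl(model_logp: np.ndarray,
                 sample_logp: np.ndarray,
                 total_tokens: float) -> float:
    """model_logp, sample_logp: (num_sequences, num_samples) log-probabilities."""
    # Log importance weights: log p(x, z_k) - log q(z_k | x).
    log_weights = model_logp - sample_logp
    # Monte Carlo estimate of log p(x) for each sequence:
    # logsumexp over the samples minus log(num_samples).
    num_samples = log_weights.shape[1]
    log_px = np.logaddexp.reduce(log_weights, axis=1) - np.log(num_samples)
    # Corpus-level perplexity: exp of the negative average token log-likelihood.
    return float(np.exp(-log_px.sum() / total_tokens))

# Illustrative usage: two sequences, four samples each, 30 tokens total.
rng = np.random.default_rng(0)
fake_model_logp = -np.abs(rng.normal(40, 5, size=(2, 4)))
fake_sample_logp = -np.abs(rng.normal(35, 5, size=(2, 4)))
print(estimate_ppl(fake_model_logp, fake_sample_logp, total_tokens=30.0))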
+# PRESERVED_FIELDS = {'source', 'reset'} +# +# +# def _offset(sample, held_over_data): +# batch_size = sample['reset'].size(0) +# new_sample = {'source': sample['source'], +# 'reset': sample['reset']} +# new_held_over_data = {} +# for field in sample: +# if field in PRESERVED_FIELDS: +# continue +# if held_over_data is None: +# prefix = sample[field].new_zeros(batch_size) +# else: +# prefix = held_over_data[field] +# new_sample[field] = torch.cat((prefix.unsqueeze(1), sample[field][:,:-1]), dim=1) +# new_held_over_data[field] = sample[field][:,-1] +# return new_sample, new_held_over_data + +UNSPLIT_FIELDS = {'reset', 'metadata', 'shortlist'} +def split(batch, split_size: int): + sequence_length = batch['source']['tokens'].shape[1] + num_splits = sequence_length // split_size + if not ((sequence_length % split_size) == 0): + num_splits += 1 + else: + logger.warning('Perfect split') + + def _chunk(x, start, stop): + if isinstance(x, dict): + return {k: v if k in UNSPLIT_FIELDS else _chunk(v, start, stop) for k, v in x.items()} + if isinstance(x, torch.Tensor): + return x[:, start:stop].contiguous() + chunks = [] + for i in range(num_splits): + chunk = _chunk(batch, i * split_size, (i + 1) * split_size) -def _offset(sample, held_over_data): - batch_size = sample['reset'].size(0) - new_sample = {'source': sample['source'], - 'reset': sample['reset']} - new_held_over_data = {} - for field in sample: - if field in PRESERVED_FIELDS: - continue - if held_over_data is None: - prefix = sample[field].new_zeros(batch_size) - else: - prefix = held_over_data[field] - new_sample[field] = torch.cat((prefix.unsqueeze(1), sample[field][:,:-1]), dim=1) - new_held_over_data[field] = sample[field][:,-1] - return new_sample, new_held_over_data + if i > 0: + chunk['reset'] = torch.zeros_like(chunk['reset']) + + chunks.append(chunk) + return chunks def tile(t, amount): @@ -110,19 +136,21 @@ def tile(t, amount): return t.repeat(*args) elif isinstance(t, dict): return {k: tile(v, amount) for k, v in t.items()} + elif isinstance(t, list): + return [x for x in t for _ in range(amount)] -def logsumexp(prev: torch.FloatTensor, - current: torch.FloatTensor, - i: int, - samples_per_batch: int): - # NOTE: n is number of samples - current_avg = current.view(samples_per_batch, -1).sum(dim=-1).logsumexp(dim=0) - np.log(samples_per_batch).item() - if prev is None: - return current_avg - a = torch.max(prev, current_avg) - sumexp = torch.exp(prev - a) * i / (i + 1) + torch.exp(current_avg - a) / (i + 1) - return a + torch.log(sumexp) +# def logsumexp(prev: torch.FloatTensor, +# current: torch.FloatTensor, +# i: int, +# samples_per_batch: int): +# # NOTE: n is number of samples +# current_avg = current.view(samples_per_batch, -1).sum(dim=-1).logsumexp(dim=0) - np.log(samples_per_batch).item() +# if prev is None: +# return current_avg +# a = torch.max(prev, current_avg) +# sumexp = torch.exp(prev - a) * i / (i + 1) + torch.exp(current_avg - a) / (i + 1) +# return a + torch.log(sumexp) def evaluate_perplexity(model: Model, @@ -133,116 +161,110 @@ def evaluate_perplexity(model: Model, cuda_device: int, temperature: float = 1.0, offset: bool = False, - samples_per_batch: int = 1) -> Dict[str, Any]: - check_for_gpu(cuda_device) + samples_per_batch: int = 1, + split_size: int = int(1e10)) -> Dict[str, Any]: + check_for_gpu(cuda_device) logger.info('Iterating over dataset') + # weight = None - # summands = [] - # penalized_summands = [] - trajectory = np.zeros(num_samples // samples_per_batch) - individual_estimates = 
np.zeros(num_samples // samples_per_batch) - s_probs = np.zeros((348, num_samples // samples_per_batch)) + model.eval() + sampler.eval() + iterator = data_iterator(instances, num_epochs=1, shuffle=False) + generator_tqdm = Tqdm.tqdm(iterator, total=0) - weight = None + summand = 0.0 + denom = 0.0 + fp = [] + q = [] + all_weights = [] - for i in range(num_samples // samples_per_batch): - iterator = data_iterator(instances, num_epochs=1, shuffle=False) - generator_tqdm = Tqdm.tqdm(iterator, total=0) + for batch in generator_tqdm: - model.eval() - sampler.eval() - sampler._state = None + batch_size = batch['reset'].shape[0] - summand = None - denom = None - #summand = torch.tensor(0.0) - # penalized_summand = torch.tensor(0.0) + n_tokens = util.get_text_field_mask(batch['source']).float().sum() + denom += n_tokens - held_over_data = None + epoch_weights = [] + epoch_fp = [] + epoch_q = [] - for batch, _ in generator_tqdm: + batch = util.move_to_device(batch, cuda_device) - # We need sequence length to help compute perplexity - n_tokens = util.get_text_field_mask(batch['source']).float().sum(dim=-1) - if denom is None: - denom = n_tokens - else: - denom += n_tokens + # Tile if that's what we're doing + if samples_per_batch > 1: + batch = tile(batch, samples_per_batch) - summand = util.move_to_device(summand, cuda_device) - batch = util.move_to_device(batch, cuda_device) + for i in range(num_samples // samples_per_batch): - # Tile if that's what we're doing - if samples_per_batch > 1: - batch = tile(batch, samples_per_batch) + # summand = util.move_to_device(summand, cuda_device) + # batch = util.move_to_device(batch, cuda_device) - # Draw a sample - with torch.no_grad(): - sampler_output = sampler.sample(**batch, - temperature=temperature, - offset=offset) - sample_logp = sampler_output['logp'] - sample = sampler_output['sample'] + weights = None + for j, chunk in enumerate(split(batch, split_size)): + generator_tqdm.set_description(f"i={i} j={j}") - if offset: - sample, held_over_data = _offset(sample, held_over_data) + chunk_tokens = util.get_text_field_mask(batch['source']).int().sum() + if chunk_tokens == 0: + logger.debug('Zero chunk, skipping') + continue - # Evaluate on sample - with torch.no_grad(): - model_output = model(**sample) + # Draw a sample + with torch.no_grad(): + sampler_output = sampler.sample(**chunk, + temperature=temperature, + offset=offset) + sample_logp = sampler_output['logp'] + sample = sampler_output['sample'] - model_logp = model_output['logp'] - if summand is None: - summand = (model_logp - sample_logp) - else: - summand += (model_logp - sample_logp) + # if offset: + # sample, held_over_data = _offset(sample, held_over_data) - # model_penalized_logp = model_output['penalized_logp'] - # penalized_summand += (model_penalized_logp - sample_logp) + with torch.no_grad(): + model_output = model(**sample) - # generator_tqdm.set_description('Instantaneous PPL: %0.4f' % torch.exp((sample_logp - model_logp) / n_tokens).item()) + model_logp = model_output['logp'] + split_weights = (model_logp - sample_logp).view(batch_size, samples_per_batch) + if weights is None: + weights = split_weights + else: + weights += split_weights + # logger.debug(torch.exp(-split_weights/split_size)) - current_avg = summand.view(samples_per_batch, -1).sum(dim=-1).logsumexp(dim=0) - np.log(samples_per_batch).item() - instance_ppl = torch.exp(-current_avg.sum() / denom.sum()) + epoch_weights.append(weights) #.cpu()) + epoch_fp.append(model_logp.view(batch_size, samples_per_batch))# .cpu()) + 
epoch_q.append(sample_logp.view(batch_size, samples_per_batch))# .cpu()) - weight = logsumexp(weight, summand, i, samples_per_batch) - ppl = torch.exp(-weight / denom.sum()) + # Combine all the epoch weights + combined_weights = torch.cat(epoch_weights, dim=1) + combined_fp = torch.cat(epoch_fp, dim=1) + combined_q = torch.cat(epoch_q, dim=1) + all_weights.append(combined_weights) + fp.append(combined_fp) + q.append(combined_q) - individual_estimates[i] = instance_ppl.item() - trajectory[i] = ppl.item() + # Compute importance sampled logp of the sequences in the batch + logp_hat = combined_weights.logsumexp(dim=1) - math.log(samples_per_batch) + summand += logp_hat.sum() - s_probs[:, i] = torch.exp(-summand.cpu() / denom.cpu()).numpy() - # summands.append(summand) - # # penalized_summands.append(penalized_summand) - # # if i == 0: - # # t = summand.unsqueeze(0) - # # p = penalized_summand.unsqueeze(0) - # # else: - # # t = torch.stack(summands, dim=0) - # # # p = torch.stack(penalized_summands, dim=0) - # t = torch.cat(summands, dim=0) - # t_sum = torch.logsumexp(t, dim=0) - # # p_sum = torch.logsumexp(p, dim=0) - # sum_logp = (t_sum - math.log((i+1)*1000)).item() - # # sum_logp_penalized = (p_sum - math.log((i+1)*1000)).item() - # ppl = math.exp(-sum_logp / 659) - # # upp = math.exp(-sum_logp_penalized / denom) + logger.info(f'PPL: {torch.exp(-summand / denom)}') - # trajectory[i] = ppl - # # individual_estimates[i] = math.exp(-summand.item() / denom) + # Create array of all the weights + all_weights_array = torch.cat(all_weights, dim=0).cpu().numpy() + fp_array = torch.cat(fp, dim=0).cpu().numpy() + q_array = torch.cat(q, dim=0).cpu().numpy() - # print('PPL: %f' % ppl) - # # print('UPP: %f' % upp) + # Compute perplexity + ppl = torch.exp(-summand / denom) metrics = { 'ppl': ppl, - # 'upp': upp, - 'trajectory': trajectory, - 'individual_estimates': individual_estimates, - 's_probs': s_probs + 'weights': all_weights_array, + 'fp': fp_array, + 'q': q_array } return metrics @@ -251,6 +273,7 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]: logging.getLogger('allennlp.common.params').disabled = True logging.getLogger('allennlp.nn.initializers').disabled = True logging.getLogger('allennlp.modules.token_embedders.embedding').setLevel(logging.INFO) + logger.warning('This code will return improper results if sequences are split') # Load model from archive model_archive = load_archive(args.model_archive_file, args.cuda_device, args.overrides, args.weights_file) @@ -274,29 +297,27 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]: logger.info('Reading evaluation data from: %s', evaluation_data_path) instances = dataset_reader.read(evaluation_data_path) - # To avoid hairy issues with splitting, we opt to use a basic iterator so that we can - # generate samples for entire sequences. iterator_params = config.pop('iterator', 'None') - if args.batch_size is not None: - iterator_params['batch_size'] = args.batch_size - if args.split_size is not None: - iterator_params['split_size'] = args.split_size - iterator_params['truncate'] = False + iterator_params['batch_size'] = args.batch_size + # Make split size really large to prevent splits (otherwise we'd have to + # deal with averaging the importance samples across splits ... + # if args.split_size is not None: + # iterator_params['split_size'] = args.split_size + iterator_params['split_size'] = int(1e10) + iterator_params['truncate'] = False # TODO: Shouldn't need this anymore... 
iterator = DataIterator.from_params(iterator_params) iterator.index_with(model.vocab) metrics = evaluate_perplexity(model, sampler, args.num_samples, instances, iterator, args.cuda_device, args.temperature, - args.offset, args.samples_per_batch) + args.offset, args.samples_per_batch, + args.split_size) logger.info('Finished evaluating.') - logger.info('Metrics:') - for key, metric in metrics.items(): - logger.info('%s: %s', key, metric) - - output_file = args.output_file - if output_file: - np.save(output_file + '.trajectory.npy', metrics['trajectory']) - np.save(output_file + '.individual_estimates.npy', metrics['individual_estimates']) - np.save(output_file + '.s_probs.npy', metrics['s_probs']) + + if args.out: + np.save(args.out + '_weights.npy', metrics['weights']) + np.save(args.out + '_fp.npy', metrics['fp']) + np.save(args.out + '_q.npy', metrics['q']) + return metrics diff --git a/kglm/data/__init__.py b/kglm/data/__init__.py index 3a4cecd..b17416b 100644 --- a/kglm/data/__init__.py +++ b/kglm/data/__init__.py @@ -2,3 +2,4 @@ from .fields import SequentialArrayField from .iterators import SplitIterator from .extended_vocabulary import ExtendedVocabulary +from .dataset_readers import * diff --git a/kglm/data/alias_database.py b/kglm/data/alias_database.py index 1200469..25427de 100644 --- a/kglm/data/alias_database.py +++ b/kglm/data/alias_database.py @@ -90,6 +90,14 @@ def load(cls, path: str): id_array_lookup=id_array_lookup, token_to_entity_lookup=token_to_entity_lookup) + def nested_token_to_uid(self, e, t): + if isinstance(e, list) and isinstance(t, list): + return [self.nested_token_to_uid(_e, _t) for _e, _t in zip(e, t)] + elif isinstance(e, str) and isinstance(t, str): + return self.token_to_uid(e, t) + else: + raise ValueError(f'Encountered error looking up copy indices:\ne:{e}\nt:{t}') + def token_to_uid(self, entity: str, token: str) -> int: if entity in self._id_map_lookup: id_map = self._id_map_lookup[entity] diff --git a/kglm/data/dataset_readers/conll2012.py b/kglm/data/dataset_readers/conll2012.py index f5e58e7..a663b1a 100644 --- a/kglm/data/dataset_readers/conll2012.py +++ b/kglm/data/dataset_readers/conll2012.py @@ -188,7 +188,7 @@ def text_to_instance(self, # type: ignore # Initialize fields. 
entity_types = np.zeros(shape=(len(tokens),)) entity_ids = np.zeros(shape=(len(tokens),)) - mention_lengths = np.ones(shape=(len(tokens),)) + mention_lengths = np.zeros(shape=(len(tokens),)) if cluster_dict: for cluster, entity_id in cluster_dict.items(): @@ -199,7 +199,7 @@ def text_to_instance(self, # type: ignore entity_ids[cluster[0] + 1:cluster[1] + 1 + 1] = entity_id entity_length = (cluster[1] + 1) - cluster[0] # Fill in mention length - mention_lengths[cluster[0] + 1:cluster[1] + 1 + 1] = np.arange(entity_length, 0, step=-1) + mention_lengths[cluster[0] + 1:cluster[1] + 1 + 1] = np.arange(entity_length, 0, step=-1) - 1 fields['entity_ids'] = SequentialArrayField(entity_ids, dtype=np.int64) fields['mention_lengths'] = SequentialArrayField(mention_lengths, dtype=np.int64) @@ -233,20 +233,18 @@ def text_to_instance(self, tokens: List[str], clusters: Dict[str, List[Tuple[int, int]]]) -> Instance: # pylint: disable=arguments-differ - tokens = [_normalize_word(x, self._replace_numbers) for x in tokens] - tokens = ['@@START@@', *tokens, '@@END@@'] fields = {'source': TextField([Token(x) for x in tokens], self._token_indexers)} entity_types = np.zeros(shape=(len(tokens),)) entity_ids = np.zeros(shape=(len(tokens),)) - mention_lengths = np.ones(shape=(len(tokens),)) + mention_lengths = np.zeros(shape=(len(tokens),)) for i, cluster in enumerate(clusters.values()): for span in cluster: start, end = span - entity_types[(start + 1 - self._offset):(end + 1 - self._offset)] = 1 - entity_ids[(start + 1 - self._offset):(end + 1 - self._offset)] = i + 1 - mention_lengths[(start + 1 - self._offset):(end + 1 - self._offset)] = np.arange(end - start, 0, step=-1) + entity_types[(start - self._offset):(end - self._offset)] = 1 + entity_ids[(start - self._offset):(end - self._offset)] = i + 1 + mention_lengths[(start - self._offset):(end - self._offset)] = np.arange(end - start, 0, step=-1) - 1 fields['entity_types'] = SequentialArrayField(entity_types, dtype=np.uint8) fields['entity_ids'] = SequentialArrayField(entity_ids, dtype=np.int64) diff --git a/kglm/data/dataset_readers/enhanced_wikitext.py b/kglm/data/dataset_readers/enhanced_wikitext.py index 3577d14..d646e47 100644 --- a/kglm/data/dataset_readers/enhanced_wikitext.py +++ b/kglm/data/dataset_readers/enhanced_wikitext.py @@ -91,7 +91,7 @@ def text_to_instance(self, data: Dict[str, Any]) -> Instance: # pylint: disable seen_entities: Set[str] = set() entity_types = np.zeros(shape=(len(tokens),)) entity_ids = np.zeros(shape=(len(tokens),)) - mention_lengths = np.ones(shape=(len(tokens),)) + mention_lengths = np.zeros(shape=(len(tokens),)) # Process annotations for annotation in data['annotations']: @@ -105,7 +105,7 @@ def text_to_instance(self, data: Dict[str, Any]) -> Instance: # pylint: disable # Note: +1 offset to account for start token. 
entity_types[i] = 1 entity_ids[i] = len(seen_entities) - mention_lengths[i] = length + mention_lengths[i] = length - 1 length -= 1 fields['entity_types'] = SequentialArrayField(entity_types, dtype=np.uint8) @@ -362,7 +362,6 @@ def _read(self, file_path: str) -> Iterable[Instance]: def text_to_instance(self, data: Dict[str, Any]) -> Instance: # pylint: disable=arguments-differ # Flatten and pad tokens tokens = _flatten(data['tokens']) - tokens = ['@@START@@', *tokens, '@@END@@'] source = [Token(x) for x in tokens[:-1]] target = [Token(x) for x in tokens[1:]] fields = { diff --git a/kglm/data/iterators/fancy_iterator.py b/kglm/data/iterators/fancy_iterator.py index 2c6a5df..1186eae 100644 --- a/kglm/data/iterators/fancy_iterator.py +++ b/kglm/data/iterators/fancy_iterator.py @@ -105,7 +105,7 @@ def __call__(self, batch.index_instances(self.vocab) padding_lengths = batch.get_padding_lengths() - yield batch.as_tensor_dict(padding_lengths), 1 + yield batch.as_tensor_dict(padding_lengths) self._epochs[key] = epoch + 1 diff --git a/kglm/models/__init__.py b/kglm/models/__init__.py index e69de29..e6ff16c 100644 --- a/kglm/models/__init__.py +++ b/kglm/models/__init__.py @@ -0,0 +1,2 @@ +from .entity_disc import EntityNLMDiscriminator +from .entity_nlm import EntityNLM diff --git a/kglm/models/entity_disc.py b/kglm/models/entity_disc.py index 692becd..3893119 100644 --- a/kglm/models/entity_disc.py +++ b/kglm/models/entity_disc.py @@ -2,7 +2,7 @@ Discriminative version of EntityNLM for importance sampling. """ import logging -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from allennlp.nn.util import get_text_field_mask from allennlp.data.vocabulary import Vocabulary @@ -17,7 +17,7 @@ import torch import torch.nn.functional as F -from kglm.modules import DynamicEmbedding, WeightDrop +from kglm.modules import DynamicEmbedding, WeightDroppedLstm from kglm.nn.util import sample_from_logp logger = logging.getLogger(__name__) @@ -76,22 +76,10 @@ def __init__(self, self._max_embeddings = max_embeddings self._sos_token = self.vocab.get_token_index('@@START@@', 'tokens') self._eos_token = self.vocab.get_token_index('@@END@@', 'tokens') - - # Rnn Encoders. 
- rnns: List[torch.nn.Module] = [] - for i in range(num_layers): - if i == 0: - input_size = embedding_dim - else: - input_size = hidden_size - if (i == num_layers - 1): - output_size = embedding_dim - else: - output_size = hidden_size - rnns.append(torch.nn.LSTM(input_size, output_size, batch_first=True)) - rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=variational_dropout_rate) for rnn in rnns] - self.rnns = torch.nn.ModuleList(rnns) - + self._rnn = WeightDroppedLstm(num_layers=num_layers, + input_embedding_dim=embedding_dim, + hidden_size=hidden_size, + dropout=variational_dropout_rate) self._state: Optional[StateDict] = None # Input variational dropout @@ -103,12 +91,14 @@ def __init__(self, out_features=2, bias=False) self._dynamic_embeddings = DynamicEmbedding(embedding_dim=embedding_dim, - max_embeddings=max_embeddings) + max_embeddings=max_embeddings, + tied_weight=self._entity_type_projection.weight) # For mention length prediction self._mention_length_projection = torch.nn.Linear(in_features=2*embedding_dim, out_features=max_mention_length) + # Metrics self._entity_type_accuracy = CategoricalAccuracy() self._entity_id_accuracy = CategoricalAccuracy() self._mention_length_accuracy = CategoricalAccuracy() @@ -167,7 +157,7 @@ def forward(self, # pylint: disable=arguments-differ return output_dict - def sample(self, + def sample(self, # pylint: disable=unused-argument source: Dict[str, torch.Tensor], reset: torch.ByteTensor = None, temperature: float = 1.0, @@ -208,9 +198,11 @@ def sample(self, self.reset_states(reset) if self._state is None: - prev_mention_lengths = source['tokens'].new_ones(batch_size) + prev_mention_lengths = source['tokens'].new_zeros(batch_size) + prev_t = source['tokens'].new_zeros(batch_size) else: prev_mention_lengths = self._state['prev_mention_lengths'] + prev_t = self._state['prev_t'] # Embed tokens and get RNN hidden state. mask = get_text_field_mask(source) @@ -228,24 +220,7 @@ def sample(self, mask = mask.byte() & eos_mask embeddings = self._text_field_embedder(source) - current_input = embeddings - hidden_list = [] - for layer, rnn in enumerate(self.rnns): - # Retrieve previous hidden state for layer. - if self._state is not None: - prev_hidden = self._state['layer_%i' % layer] - else: - prev_hidden = None - # Forward-pass. - output, hidden = rnn(current_input, prev_hidden) - output = output.contiguous() - # Update hidden state for layer. - hidden = tuple(h.detach() for h in hidden) - hidden_list.append(hidden) - current_input = output - hidden = current_input - - self._state = {'layer_%i' % i: h for i, h in enumerate(hidden_list)} + hidden = self._rnn(embeddings) # Initialize outputs logp = hidden.new_zeros(batch_size) # Track total logp for **each** generated sample @@ -258,7 +233,7 @@ def sample(self, current_hidden = hidden[:, timestep] # We only predict types / ids / lengths if the previous mention is terminated. 
- predict_mask = prev_mention_lengths == 1 + predict_mask = prev_mention_lengths == 0 predict_mask = predict_mask & mask[:, timestep].byte() if predict_mask.any(): @@ -276,7 +251,7 @@ def sample(self, if predict_em.any(): # Predict entity ids entity_id_prediction_outputs = self._dynamic_embeddings(hidden=current_hidden, - timestep=timestep, + timestep=prev_t, mask=predict_em) entity_id_logits = entity_id_prediction_outputs['logits'] / temperature entity_id_mask = entity_id_prediction_outputs['logit_mask'] @@ -310,17 +285,16 @@ def sample(self, # Add / update entity embeddings new_entities = entity_ids[:, timestep] == self._dynamic_embeddings.num_embeddings self._dynamic_embeddings.add_embeddings(timestep, new_entities) - self._dynamic_embeddings.update_embeddings(hidden=current_hidden, update_indices=entity_ids[:, timestep], - timestep=timestep, + timestep=prev_t, mask=predict_em) # If the previous mentions are ongoing, we assign the output deterministically. Mention # lengths decrease by 1, all other outputs are copied from the previous timestep. Do # not need to add anything to logp since these 'predictions' have probability 1 under # the model. - deterministic_mask = prev_mention_lengths > 1 + deterministic_mask = prev_mention_lengths > 0 deterministic_mask = deterministic_mask & mask[:, timestep].byte() if deterministic_mask.any(): entity_types[deterministic_mask, timestep] = entity_types[deterministic_mask, timestep - 1] @@ -329,9 +303,11 @@ def sample(self, # Update mention lengths for next timestep prev_mention_lengths = mention_lengths[:, timestep] + prev_t += 1 # Update state - self._state['prev_mention_lengths'] = prev_mention_lengths.detach() + self._state = {'prev_mention_lengths': prev_mention_lengths.detach(), + 'prev_t': prev_t.detach()} return { 'logp': logp, @@ -344,6 +320,311 @@ def sample(self, } } + @property + def num_possible_annotations(self): + # Number of ways to annotate an entity mention + 1 way to annotate a non-entity mention. + return self._max_embeddings * self._max_mention_length + 1 + + @property + def entity_type_lookup(self): + entity_type_lookup = [0] + [1] * self._max_embeddings * self._max_mention_length + return torch.ByteTensor(entity_type_lookup) + + @property + def entity_id_lookup(self): + entity_id_lookup = [0] + [i for i in range(self._max_embeddings) for _ in range(self._max_mention_length)] + return torch.LongTensor(entity_id_lookup) + + @property + def mention_length_lookup(self): + mention_length_lookup = [0] + list(range(self._max_mention_length)) * self._max_embeddings + return torch.LongTensor(mention_length_lookup) + + def _annotation_logp(self, + hidden: torch.FloatTensor, + timestep: torch.LongTensor, + beam_states: List[Dict[str, Any]]) -> torch.Tensor: + """Computes the log-probability of all possible annotations for a single beam state. + + Parameters + ========== + TODO: Fill in + + Returns + ======= + A tensor of log-probabilities for the possible annotations of shape + (batch_size, sequence_length, num_annotations). 
+ """ + batch_size, hidden_dim = hidden.shape + logp = hidden.new_zeros((batch_size, len(beam_states), self.num_possible_annotations)) + + for i, beam_state in enumerate(beam_states): + self._dynamic_embeddings.load_beam_state(beam_state) + + # Entity type log probabilities: (batch_size, 2) + entity_type_logits = self._entity_type_projection(hidden) + entity_type_logp = F.log_softmax(entity_type_logits, -1) + + # Entity id log probabilities: (batch_size, max_embeddings) + # Should be okay to use timestep instead of prev_t since there's no + # splitting in beam search. + entity_id_logits = self._dynamic_embeddings(hidden, timestep)['logits'] + entity_id_logp = F.log_softmax(entity_id_logits, -1) + + # Mention length log probabilites: (batch_size, max_embeddings x max_mention_lengths) + # NOTE: Entity id is guaranteed to be zero at initialization + embeddings = self._dynamic_embeddings.embeddings + concatenated = torch.cat((hidden.unsqueeze(1).expand_as(embeddings), embeddings), dim=-1) + mention_length_logits = self._mention_length_projection(concatenated) + mention_length_logp = F.log_softmax(mention_length_logits, -1).view(batch_size, -1) + + # Lastly, we need to tile entity id log probs properly. + entity_id_logp = entity_id_logp.unsqueeze(-1).repeat(1, 1, self._max_mention_length).view(batch_size, -1) + + logp[:, i, 0] += entity_type_logp[:, 0] + logp[:, i, 1:] += entity_type_logp[:, 1:] + logp[:, i, 1:] += entity_id_logp + logp[:, i ,1:] += mention_length_logp + + return logp + + def _adjust_for_ongoing_mentions(self, + logp: torch.FloatTensor, + output: Dict[str, torch.FloatTensor] = None) -> torch.FloatTensor: + """Fixes logp so that ongoing mentions are deterministic.""" + if output is None: + return logp + + mention_lengths = output['mention_lengths'] + entity_ids = output['entity_ids'] + + # Find ongoing mentions. + ongoing = mention_lengths > 0 + + # Make probability zero for all ongoing entries... + logp[ongoing] = -float('inf') + + # ...except the deterministic output + new_lengths = mention_lengths[ongoing] - 1 + entity_ids = entity_ids[ongoing] + annotation_idx = 1 + entity_ids * self._max_mention_length + new_lengths + logp[ongoing, annotation_idx] = 0 + + return logp + + def _top_k_annotations(self, + logp: torch.FloatTensor, + k: int): + """Extracts the top-k annotations. 
+ + Parameters + ========== + logp : torch.Tensor + A (batch_size, beam_width, num_annotations tensor) + """ + batch_size = logp.shape[0] + # Get the top canidates from each beam + # (batch_size, beam_width, k) + top_logp, top_indices = logp.topk(k, dim=-1) + + # Next flatten + # (batch_size, beam_width * k) + flat_logp = top_logp.view(batch_size, -1) + flat_indices = top_indices.view(batch_size, -1) + + # Get the true top k + # (batch_size, k) + top_logp, top_indices = flat_logp.topk(k, dim=-1) + + # Retrieve backpointers from the indices + # (batch_size, k) + backpointers = top_indices // k + + # Also need to index correctly into the lookup + lookup_indices = flat_indices.gather(-1, top_indices) + + # Use lookup indices to get the top annotation variables + entity_types = self.entity_type_lookup.to(device=lookup_indices.device).take(lookup_indices) + entity_ids = self.entity_id_lookup.to(device=lookup_indices.device).take(lookup_indices) + mention_lengths = self.mention_length_lookup.to(device=lookup_indices.device).take(lookup_indices) + + output = { + 'logp': top_logp, + 'backpointers': backpointers, + 'entity_types': entity_types, + 'entity_ids': entity_ids, + 'mention_lengths': mention_lengths + } + return output + + def _update_beam_states(self, + hidden: torch.FloatTensor, + timestep: torch.LongTensor, + beam_states: List[Dict[str, Any]], + output: Dict[str, torch.Tensor]) -> List[Dict[str, Any]]: + """ + Given the new beam predictions we need to add/update entity embeddings. Before we can do + this we need to follow the backpointers to assemble correct tensors of the entity embeddings. + + """ + logp = output['logp'] + backpointers = output['backpointers'] + batch_size, k = logp.shape + + # Concat all the dynamic entity embeddings and trace backpointers to make sure the proper + # embeddings are loaded for each beam. + all_prev_entity_embeddings = logp.new_zeros(batch_size, len(beam_states), self._max_embeddings, self._embedding_dim) + all_prev_num_embeddings = backpointers.new_zeros(batch_size, len(beam_states)) + all_prev_last_seen = backpointers.new_zeros(batch_size, len(beam_states), self._max_embeddings) + for i, beam_state in enumerate(beam_states): + self._dynamic_embeddings.load_beam_state(beam_state) + all_prev_entity_embeddings[:, i] = self._dynamic_embeddings.embeddings + all_prev_num_embeddings[:, i] = self._dynamic_embeddings.num_embeddings + all_prev_last_seen[:, i] = self._dynamic_embeddings.last_seen + + new_beam_states: List[Dict[str, Any]] = [] + for i in range(k): + # Trace backpointers to get correct params + self._dynamic_embeddings.embeddings = all_prev_entity_embeddings[torch.arange(batch_size), backpointers[:, i]] + self._dynamic_embeddings.num_embeddings = all_prev_num_embeddings[torch.arange(batch_size), backpointers[:, i]] + self._dynamic_embeddings.last_seen = all_prev_last_seen[torch.arange(batch_size), backpointers[:, i]] + + # Add and update embeddings + entity_ids = output['entity_ids'][:, i] + entity_types = output['entity_types'][:, i] + new_entities = (entity_ids == 0) & entity_types + + # Gotta make the output handle new entities correctly + # TODO: Be a better programmer + entity_ids[new_entities] = self._dynamic_embeddings.num_embeddings[new_entities] + output['entity_ids'][:, i] = entity_ids + + # Now do this right... + # Should be okay to use timestep instead of prev_t since there's no + # splitting in beam search. 
+ self._dynamic_embeddings.add_embeddings(timestep, new_entities) + self._dynamic_embeddings.update_embeddings(hidden=hidden, + update_indices=entity_ids, + timestep=timestep, + mask=entity_types) + + new_beam_states.append(self._dynamic_embeddings.beam_state()) + + return new_beam_states + + @staticmethod + def _trace_backpointers(source: Dict[str, torch.Tensor], + reset: torch.ByteTensor, + k: int, + predictions: List[Dict[str, torch.Tensor]]) -> Dict[str, Any]: + batch_size, seq_length = source['tokens'].shape + + new_reset = reset.unsqueeze(1).repeat(1, k).view(batch_size * k) + new_source = {key: value.unsqueeze(1).repeat(1, k, 1).view(batch_size * k, -1) for key, value in source.items()} + + entity_types = [] + entity_ids = [] + mention_lengths = [] + backpointer = None + + for prediction in reversed(predictions): + if backpointer is None: + entity_types.append(prediction['entity_types']) + entity_ids.append(prediction['entity_ids']) + mention_lengths.append(prediction['mention_lengths']) + else: + entity_types.append(prediction['entity_types'].gather(1, backpointer)) + entity_ids.append(prediction['entity_ids'].gather(1, backpointer)) + mention_lengths.append(prediction['mention_lengths'].gather(1, backpointer)) + if backpointer is None: + backpointer = prediction['backpointers'] + else: + backpointer = prediction['backpointers'].gather(1, backpointer) + + entity_types = torch.stack(entity_types[::-1], dim=-1).view(batch_size * k, -1) + entity_ids = torch.stack(entity_ids[::-1], dim=-1).view(batch_size * k, -1) + mention_lengths = torch.stack(mention_lengths[::-1], dim=-1).view(batch_size * k , -1) + + return { + 'reset': new_reset, + 'source': new_source, + 'entity_types': entity_types, + 'entity_ids': entity_ids, + 'mention_lengths': mention_lengths + } + + def beam_search(self, + source: Dict[str, torch.Tensor], + reset: torch.ByteTensor, + k: int, + **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Obtain the top-k (approximately) most likely predictions from the model using beam + search. Unlike typical beam search all of the beam states are returned instead of just + the most likely. + + The returned candidates are intended to be marginalized over to obtain an upper bound for + the token-level perplexity of the EntityNLM. + + Parameters + ========== + source : ``Dict[str, torch.Tensor]`` + A tensor of shape ``(batch_size, sequence_length)`` containing the sequence of + tokens. + reset : ``torch.ByteTensor`` + Whether or not to reset the model's state. This should be done at the start of each + new sequence. + k : ``int`` + Number of predictions to return. + + Returns + ======= + predictions : ``torch.Tensor`` + A tensor of shape ``(batch_size * k, sequence_length)`` containing the top-k + predictions. + logp : ``torch.Tensor`` + The log-probabilities of each prediction. WARNING: These are returned purely for + diagnostic purposes and should not be factored in the the perplexity calculation. + """ + batch_size, sequence_length = source['tokens'].shape + + # Reset the model's internal state. + if not reset.all(): + raise RuntimeError('Detecting that not all states are being `reset` (e.g., that input ' + 'sequences have been split). Cannot predict top-K annotations in ' + 'this setting!') + self.reset_states(reset) + prev_mention_lengths = source['tokens'].new_zeros(batch_size) + prev_t = source['tokens'].new_zeros(batch_size) + + # Embed and encode the tokens up front. 
+ embeddings = self._text_field_embedder(source) + hidden = self._rnn(embeddings) + + # Beam search logic + predictions: List[Dict[str, torch.Tensor]] = [] + beam_states = [self._dynamic_embeddings.beam_state()] + output = None + for timestep in range(sequence_length): + # Get log probabilities of annotations + # (batch_size, k, num_annotations) + logp = self._annotation_logp(hidden[:, timestep], prev_t, beam_states) + # Accout for ongoing mentions + logp = self._adjust_for_ongoing_mentions(logp, output) + # Add to cumulative log probabilities of beams (which have shape (batch_size, k)) + if output: + logp += output['logp'].unsqueeze(-1) + + output = self._top_k_annotations(logp, k) + beam_states = self._update_beam_states(hidden[:, timestep], prev_t, beam_states, output) + predictions.append(output) + prev_t = prev_t + 1 + + # Trace backpointers to get annotation. + annotation = self._trace_backpointers(source, reset, k, predictions) + + return annotation + def _forward_loop(self, tokens: Dict[str, torch.Tensor], entity_types: torch.Tensor, @@ -383,32 +664,17 @@ def _forward_loop(self, # Need to track previous mention lengths in order to know when to measure loss. if self._state is None: - prev_mention_lengths = mention_lengths.new_ones(batch_size) + prev_mention_lengths = mention_lengths.new_zeros(batch_size) + prev_t = mention_lengths.new_zeros(batch_size) else: prev_mention_lengths = self._state['prev_mention_lengths'] + prev_t = self._state['prev_t'] # Embed tokens and get RNN hidden state. mask = get_text_field_mask(tokens) embeddings = self._text_field_embedder(tokens) embeddings = self._variational_dropout(embeddings) - current_input = embeddings - hidden_list = [] - for layer, rnn in enumerate(self.rnns): - # Retrieve previous hidden state for layer. - if self._state is not None: - prev_hidden = self._state['layer_%i' % layer] - else: - prev_hidden = None - # Forward-pass. - output, hidden = rnn(current_input, prev_hidden) - output = output.contiguous() - # Update hidden state for layer. - hidden = tuple(h.detach() for h in hidden) - hidden_list.append(hidden) - current_input = output - hidden = current_input - - self._state = {'layer_%i' % i: h for i, h in enumerate(hidden_list)} + hidden = self._rnn(embeddings) # Initialize losses entity_type_loss = torch.tensor(0.0, requires_grad=True, device=hidden.device) @@ -426,7 +692,7 @@ def _forward_loop(self, # We only predict types / ids / lengths if we are not currently in the process of # generating a mention (e.g. if the previous remaining mention length is 1). Indexing / # masking with ``predict_all`` makes it possible to do this in batch. - predict_all = prev_mention_lengths == 1 + predict_all = prev_mention_lengths == 0 predict_all = predict_all & mask[:, timestep].byte() if predict_all.any(): @@ -449,7 +715,7 @@ def _forward_loop(self, modified_entity_ids = current_entity_ids.clone() modified_entity_ids[modified_entity_ids == self._dynamic_embeddings.num_embeddings] = 0 entity_id_prediction_outputs = self._dynamic_embeddings(hidden=current_hidden, - timestep=timestep, + timestep=prev_t, target=modified_entity_ids, mask=predict_em) _entity_id_loss = -entity_id_prediction_outputs['loss'] @@ -478,10 +744,11 @@ def _forward_loop(self, # We also perform updates of the currently observed entities. 
self._dynamic_embeddings.update_embeddings(hidden=current_hidden, update_indices=current_entity_ids, - timestep=timestep, + timestep=prev_t, mask=current_entity_types) prev_mention_lengths = current_mention_lengths + prev_t += 1 # Normalize the losses entity_type_loss = entity_type_loss / mask.sum() @@ -498,9 +765,11 @@ def _forward_loop(self, 'loss': total_loss } - # Update state # Update the model state - self._state['prev_mention_lengths'] = mention_lengths[:, -1].detach() + self._state = { + 'prev_mention_lengths': mention_lengths[:, -1].detach(), + 'prev_t': prev_t.detach() + } return output_dict @@ -508,16 +777,12 @@ def reset_states(self, reset: torch.ByteTensor) -> None: """Resets the model's internals. Should be called at the start of a new batch.""" if reset.any() and (self._state is not None): # Zero out any previous elements - self._state['prev_mention_lengths'][reset] = 1 - # Zero out the hidden state - for layer in range(self._num_layers): - h, c = self._state['layer_%i' % layer] - h[:, reset, :] = torch.zeros_like(h[:, reset, :]) - c[:, reset, :] = torch.zeros_like(c[:, reset, :]) - self._state['layer_%i' % layer] = (h, c) - - # Reset the dynamic embeddings + self._state['prev_mention_lengths'][reset] = 0 + self._state['prev_t'][reset] = 0 + + # Reset the dynamic embeddings and lstm self._dynamic_embeddings.reset_states(reset) + self._rnn.reset(reset) def detach_states(self): """Detaches the model's state to enforce truncated backpropagation.""" @@ -526,7 +791,7 @@ def detach_states(self): @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: return { - 'et_acc': self._entity_type_accuracy.get_metric(reset), - 'eid_acc': self._entity_id_accuracy.get_metric(reset), - 'ml_acc': self._mention_length_accuracy.get_metric(reset) + 'et_acc': self._entity_type_accuracy.get_metric(reset), + 'eid_acc': self._entity_id_accuracy.get_metric(reset), + 'ml_acc': self._mention_length_accuracy.get_metric(reset) } diff --git a/kglm/models/entity_nlm.py b/kglm/models/entity_nlm.py index 63b7ffd..2b699c1 100644 --- a/kglm/models/entity_nlm.py +++ b/kglm/models/entity_nlm.py @@ -17,7 +17,7 @@ from torch.nn import Parameter import torch.nn.functional as F -from kglm.modules import DynamicEmbedding, WeightDrop +from kglm.modules import DynamicEmbedding, WeightDroppedLstm from kglm.training.metrics import Ppl # from kglm.training.metrics import Perplexity, UnknownPenalizedPerplexity @@ -81,22 +81,10 @@ def __init__(self, self._tie_weights = tie_weights self._variational_dropout_rate = variational_dropout_rate self._dropout_rate = dropout_rate - - # Rnn Encoders. 
- rnns: List[torch.nn.Module] = [] - for i in range(num_layers): - if i == 0: - input_size = embedding_dim - else: - input_size = hidden_size - if (i == num_layers - 1): - output_size = embedding_dim - else: - output_size = hidden_size - rnns.append(torch.nn.LSTM(input_size, output_size, batch_first=True)) - rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=variational_dropout_rate) for rnn in rnns] - self.rnns = torch.nn.ModuleList(rnns) - + self._rnn = WeightDroppedLstm(num_layers=num_layers, + input_embedding_dim=embedding_dim, + hidden_size=hidden_size, + dropout=variational_dropout_rate) self._state: Optional[StateDict] = None # Input variational dropout @@ -108,35 +96,27 @@ def __init__(self, out_features=2, bias=False) self._dynamic_embeddings = DynamicEmbedding(embedding_dim=embedding_dim, - max_embeddings=max_embeddings) + max_embeddings=max_embeddings, + tied_weight=self._entity_type_projection.weight) # For mention length prediction self._mention_length_projection = torch.nn.Linear(in_features=2*embedding_dim, out_features=max_mention_length) # For next word prediction - self._dummy_context_embedding = Parameter(F.normalize(torch.randn(1, embedding_dim))) # TODO: Maybe squeeze self._entity_output_projection = torch.nn.Linear(in_features=embedding_dim, out_features=embedding_dim, bias=False) - self._context_output_projection = torch.nn.Linear(in_features=embedding_dim, - out_features=embedding_dim, - bias=False) self._vocab_projection = torch.nn.Linear(in_features=embedding_dim, out_features=vocab.get_vocab_size('tokens')) if tie_weights: self._vocab_projection.weight = self._text_field_embedder._token_embedders['tokens'].weight # pylint: disable=W0212 - # self._perplexity = Perplexity() - # self._unknown_penalized_perplexity = UnknownPenalizedPerplexity(self.vocab) self._entity_type_accuracy = CategoricalAccuracy() self._entity_id_accuracy = CategoricalAccuracy() self._mention_length_accuracy = CategoricalAccuracy() self._perplexity = Ppl() - if tie_weights: - self._vocab_projection.weight = self._text_field_embedder._token_embedders['tokens'].weight # pylint: disable=W0212 - initializer(self) @overrides @@ -241,32 +221,16 @@ def _forward_loop(self, mention_lengths = torch.cat((self._state['prev_mention_lengths'], mention_lengths), dim=1) contexts = self._state['prev_contexts'] sequence_length += 1 + prev_t = self._state['prev_t'] else: - contexts = self._dummy_context_embedding.repeat(batch_size, 1) + contexts = tokens['tokens'].new_zeros(batch_size, self._embedding_dim, dtype=torch.float32) + prev_t = tokens['tokens'].new_zeros(batch_size) # Embed tokens and get RNN hidden state. mask = get_text_field_mask(tokens).byte() embeddings = self._text_field_embedder(tokens) embeddings = self._variational_dropout(embeddings) - - current_input = embeddings - hidden_list = [] - for layer, rnn in enumerate(self.rnns): - # Retrieve previous hidden state for layer. - if self._state is not None: - prev_hidden = self._state['layer_%i' % layer] - else: - prev_hidden = None - # Forward-pass. - output, hidden = rnn(current_input, prev_hidden) - output = output.contiguous() - # Update hidden state for layer. 
- hidden = tuple(h.detach() for h in hidden) - hidden_list.append(hidden) - current_input = output - hidden = current_input - - self._state = {'layer_%i' % i: h for i, h in enumerate(hidden_list)} + hidden = self._rnn(embeddings[:,:-1]) # Otherwise will double count on splits # Initialize losses entity_type_loss = 0.0 @@ -285,6 +249,7 @@ def _forward_loop(self, current_entity_ids = entity_ids[:, timestep] current_mention_lengths = mention_lengths[:, timestep] current_hidden = self._dropout(hidden[:, timestep]) + current_mask = mask[:, timestep] next_entity_types = entity_types[:, timestep + 1] next_entity_ids = entity_ids[:, timestep + 1] @@ -301,7 +266,7 @@ def _forward_loop(self, # We also perform updates of the currently observed entities. self._dynamic_embeddings.update_embeddings(hidden=current_hidden, update_indices=current_entity_ids, - timestep=timestep, + timestep=prev_t, mask=current_entity_types) # This part is a little counter-intuitive. Because the above code adds a new embedding @@ -314,11 +279,10 @@ def _forward_loop(self, # require access to the **next** hidden state, which does not exist during generation). next_entity_ids = next_entity_ids.clone() # This prevents mutating the source data. next_entity_ids[next_entity_ids == self._dynamic_embeddings.num_embeddings] = 0 - # We only predict the types / ids / lengths of the next mention if we are not currently # in the process of generating it (e.g. if the current remaining mention length is 1). # Indexing / masking with ``predict_all`` makes it possible to do this in batch. - predict_all = (current_mention_lengths == 1) & next_mask + predict_all = (current_mention_lengths == 0) & next_mask & current_mask if predict_all.any(): # Equation 3 in the paper. @@ -340,7 +304,7 @@ def _forward_loop(self, if predict_em.any(): # Equation 4 in the paper. entity_id_prediction_outputs = self._dynamic_embeddings(hidden=current_hidden, - timestep=timestep, + timestep=prev_t, target=next_entity_ids, mask=predict_em) _entity_id_loss = -entity_id_prediction_outputs['loss'] @@ -376,21 +340,27 @@ def _forward_loop(self, # Always predict the next word. This is done using the hidden state and contextual bias. entity_embeddings = self._dynamic_embeddings.embeddings[next_entity_types, next_entity_ids[next_entity_types]] - entity_embeddings = self._entity_output_projection(entity_embeddings) + # entity_embeddings = self._entity_output_projection(entity_embeddings) context_embeddings = contexts[~next_entity_types] - context_embeddings = self._context_output_projection(context_embeddings) + # context_embeddings = self._context_output_projection(context_embeddings) - # The checks in the following block of code are required to prevent adding empty - # tensors to vocab_features (which causes a floating point error). 
- vocab_features = current_hidden.clone() + # Combine entity and context embeddings + combined_embeddings = torch.zeros_like(current_hidden) if next_entity_types.any(): - vocab_features[next_entity_types] = vocab_features[next_entity_types] + entity_embeddings + combined_embeddings[next_entity_types] = entity_embeddings if (~next_entity_types).any(): - vocab_features[~next_entity_types] = vocab_features[~next_entity_types] + context_embeddings - vocab_logits = self._vocab_projection(vocab_features[next_mask]) + combined_embeddings[~next_entity_types] = context_embeddings + + # Project + combined_embeddings_proj = self._entity_output_projection(combined_embeddings) + + # The checks in the following block of code are required to prevent adding empty + # tensors to vocab_features (which causes a floating point error). + vocab_features = current_hidden + combined_embeddings_proj + vocab_logits = self._vocab_projection(vocab_features[next_mask & current_mask]) vocab_logp = F.log_softmax(vocab_logits, -1) - _vocab_loss = -vocab_logp.gather(-1, next_tokens[next_mask].unsqueeze(-1)) - logp[next_mask] += -_vocab_loss.squeeze() + _vocab_loss = -vocab_logp.gather(-1, next_tokens[next_mask & current_mask].unsqueeze(-1)) + logp[next_mask & current_mask] += -_vocab_loss.squeeze() # _vocab_loss = F.cross_entropy(vocab_logits, next_tokens, reduction='none') # _vocab_loss = _vocab_loss * next_mask.float() @@ -405,7 +375,13 @@ def _forward_loop(self, # mask=next_mask.float()) # Lastly update contexts - contexts = current_hidden + contexts = combined_embeddings + prev_t += 1 + + # And to be super careful, we want to reset any rnn hidden states + # if the current token is padding (this could impact performance at + # the start of a sequence). + self._rnn.reset(~current_mask) self._perplexity(vocab_loss, mask.sum()) @@ -422,6 +398,7 @@ def _forward_loop(self, logger.debug('Vocab loss: %0.4f', vocab_loss) total_loss = entity_type_loss + entity_id_loss + mention_length_loss + vocab_loss + output_dict = { 'entity_type_loss': entity_type_loss, 'entity_id_loss': entity_id_loss, @@ -432,32 +409,34 @@ def _forward_loop(self, 'penalized_logp': -total_loss * mask.sum() } - # Update the model state - self._state['prev_tokens'] = {field: tokens[field][:, -1].unsqueeze(1).detach() for field in tokens} - self._state['prev_entity_types'] = entity_types[:, -1].unsqueeze(1).detach() - self._state['prev_entity_ids'] = entity_ids[:, -1].unsqueeze(1).detach() - self._state['prev_mention_lengths'] = mention_lengths[:, -1].unsqueeze(1).detach() - self._state['prev_contexts'] = contexts.detach() + self._state = { + 'prev_tokens': {field: tokens[field][:, -1].unsqueeze(1).detach() for field in tokens}, + 'prev_entity_types': entity_types[:, -1].unsqueeze(1).detach(), + 'prev_entity_ids': entity_ids[:, -1].unsqueeze(1).detach(), + 'prev_mention_lengths': mention_lengths[:, -1].unsqueeze(1).detach(), + 'prev_contexts': contexts.detach(), + 'prev_t': prev_t.detach() + } return output_dict def reset_states(self, reset: torch.ByteTensor) -> None: """Resets the model's internals. 
Should be called at the start of a new batch.""" + if reset.all(): + self._state = None if reset.any() and (self._state is not None): # Zero out any previous elements - self._state['prev_entity_types'][reset].zero_() - self._state['prev_entity_ids'][reset].zero_() - self._state['prev_mention_lengths'][reset].zero_() - self._state['prev_contexts'][reset].zero_() - # Zero out the hidden state - for layer in range(self._num_layers): - h, c = self._state['layer_%i' % layer] - h[:, reset, :] = torch.zeros_like(h[:, reset, :]) - c[:, reset, :] = torch.zeros_like(c[:, reset, :]) - self._state['layer_%i' % layer] = (h, c) + for field in self._state['prev_tokens']: + self._state['prev_tokens'][field][reset] = 0 + self._state['prev_entity_types'][reset] = 0 + self._state['prev_entity_ids'][reset] = 0 + self._state['prev_mention_lengths'][reset] = 0 + self._state['prev_t'][reset] = 0 + self._state['prev_contexts'][reset] = 0.0 # Reset the dynamic embeddings self._dynamic_embeddings.reset_states(reset) + self._rnn.reset(reset) def detach_states(self): """Detaches the model's state to enforce truncated backpropagation.""" @@ -470,13 +449,11 @@ def train(self, mode=True): # batch sizes (e.g. the `reset` tensor will not be the right size). In future # implementations this should be handled more robustly. super().train(mode) - self._state = None @overrides def eval(self): # TODO: See train. super().eval() - self._state = None @overrides def get_metrics(self, reset: bool = False) -> Dict[str, float]: @@ -488,3 +465,5 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]: 'ml_acc': self._mention_length_accuracy.get_metric(reset), 'ppl': self._perplexity.get_metric(reset) } diff --git a/kglm/models/kglm.py index dcf4cbb..27ade21 100644 --- a/kglm/models/kglm.py +++ b/kglm/models/kglm.py @@ -15,8 +15,8 @@ import torch.nn.functional as F from kglm.data import AliasDatabase -from kglm.modules import ( - embedded_dropout, LockedDropout, WeightDrop, KnowledgeGraphLookup, RecentEntities) +from kglm.modules import (embedded_dropout, LockedDropout, WeightDroppedLstm, + KnowledgeGraphLookup, RecentEntities) from kglm.nn.util import nested_enumerate, parallel_sample from kglm.training.metrics import Ppl @@ -86,20 +86,12 @@ def __init__(self, token_embedding_dim = token_embedder.get_output_dim() self.entity_embedding_dim = entity_embedding_dim self.token_embedding_dim = token_embedding_dim - - rnns: List[torch.nn.Module] = [] - for i in range(num_layers): - if i == 0: - input_size = token_embedding_dim - else: - input_size = hidden_size - if (i == num_layers - 1): - output_size = token_embedding_dim + 2 * entity_embedding_dim - else: - output_size = hidden_size - rnns.append(torch.nn.LSTM(input_size, output_size, batch_first=True)) - rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in rnns] - self.rnns = torch.nn.ModuleList(rnns) + rnn_output_dim = token_embedding_dim + 2 * entity_embedding_dim + self._rnn = WeightDroppedLstm(num_layers=num_layers, + input_embedding_dim=self.token_embedding_dim, + hidden_size=self._hidden_size, + output_embedding_dim=rnn_output_dim, + dropout=self._wdrop) # Various linear transformations.
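+ # The encoder output has width ``token_embedding_dim + 2 * entity_embedding_dim`` and is split downstream into separate token, head and relation encodings, which the projections below operate on.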
self._fc_mention_type = torch.nn.Linear( @@ -129,8 +121,6 @@ def __init__(self, if tie_weights: self._fc_generate.weight = self._token_embedder.weight - self._state: Optional[Dict[str, Any]] = None - # Metrics self._unk_index = vocab.get_token_index(DEFAULT_OOV_TOKEN) self._unk_penalty = math.log(vocab.get_vocab_size('tokens_unk')) @@ -293,13 +283,7 @@ def sample(self, alias_database.tensorize(vocab=self.vocab) # Reset - if reset.any() and (self._state is not None): - for layer in range(self._num_layers): - h, c = self._state['layer_%i' % layer] - h[:, reset, :] = torch.zeros_like(h[:, reset, :]) - c[:, reset, :] = torch.zeros_like(c[:, reset, :]) - self._state['layer_%i' % layer] = (h, c) - self._recent_entities.reset(reset) + self.reset_states(reset) # Get source tokens source_tokens = source['tokens'] @@ -405,13 +389,7 @@ def forward(self, # pylint: disable=arguments-differ alias_database.tensorize(vocab=self.vocab) # Reset the model if needed - if reset.any() and (self._state is not None): - for layer in range(self._num_layers): - h, c = self._state['layer_%i' % layer] - h[:, reset, :] = torch.zeros_like(h[:, reset, :]) - c[:, reset, :] = torch.zeros_like(c[:, reset, :]) - self._state['layer_%i' % layer] = (h, c) - self._recent_entities.reset(reset) + self.reset_states(reset) if target is not None: output_dict = self._forward_loop( @@ -468,8 +446,8 @@ def _forward_loop(self, # Predict whether or not the next token will be an entity mention, and if so which type. mention_type_loss = self._mention_type_loss(encoded_token, mention_type, target_mask) - self._avg_mention_type_loss(float(mention_type_loss)) - logger.debug('mention type loss: %0.4f', mention_type_loss) + logger.debug('mention loss: %0.4f', mention_type_loss.sum() / (target_mask.sum().float() + 1e-13)) + self._avg_mention_type_loss(float(mention_type_loss.sum() / (target_mask.sum().float() + 1e-13))) # For new mentions, predict which entity (among those in the supplied shortlist) will be # mentioned. @@ -486,8 +464,8 @@ def _forward_loop(self, None, target_mask) - self._avg_new_entity_loss(float(new_entity_loss)) - logger.debug('new entity loss: %0.4f', new_entity_loss) + logger.debug('new ent loss: %0.4f', new_entity_loss.sum() / (target_mask.sum().float() + 1e-13)) + self._avg_new_entity_loss(float(new_entity_loss.sum() / (target_mask.sum().float() + 1e-13))) # For derived mentions, first predict which parent(s) to expand... knowledge_graph_entity_loss = self._knowledge_graph_entity_loss(encoded_head, @@ -496,8 +474,8 @@ def _forward_loop(self, entity_ids, parent_ids, target_mask) - self._avg_knowledge_graph_entity_loss(float(knowledge_graph_entity_loss)) - logger.debug('kg entity loss: %0.4f', knowledge_graph_entity_loss) + self._avg_knowledge_graph_entity_loss(float(knowledge_graph_entity_loss.sum() / (target_mask.sum().float() + 1e-13))) + logger.debug('kg loss: %0.4f', knowledge_graph_entity_loss.sum() / (target_mask.sum().float() + 1e-13)) # Predict generation-mode scores. Note: these are W.R.T to entity_ids since we need the embedding. generate_scores = self._generate_scores(encoded_token, entity_ids) @@ -514,11 +492,13 @@ def _forward_loop(self, target_mask, alias_inds, entity_ids.gt(0)) + logger.debug('vocab loss: %0.4f', vocab_loss.sum() / (target_mask.sum().float() + 1e-13)) # Compute total loss. Also compute logp (needed for importance sampling evaluation). 
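+ # The individual losses above are kept per-sequence (summed over tokens) rather than reduced to scalars, so the ``logp`` and ``penalized_logp`` values computed below are per-sequence log-likelihoods that can be weighted sample-by-sample during importance-sampling evaluation.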
- loss = vocab_loss + mention_type_loss + new_entity_loss + knowledge_graph_entity_loss - logp = -(vocab_loss + mention_type_loss + new_entity_loss + knowledge_graph_entity_loss) * target_mask.sum() - penalized_logp = -(penalized_vocab_loss + mention_type_loss + new_entity_loss + knowledge_graph_entity_loss) * target_mask.sum() + loss = (vocab_loss + mention_type_loss + new_entity_loss + knowledge_graph_entity_loss).sum() / (target_mask.sum().float() + 1e-13) + logger.debug('loss: %0.4f', loss) + logp = -(vocab_loss + mention_type_loss + new_entity_loss + knowledge_graph_entity_loss) + penalized_logp = -(penalized_vocab_loss + mention_type_loss + new_entity_loss + knowledge_graph_entity_loss) # Activation regularization if self._alpha: @@ -680,42 +660,17 @@ def decode(self, output: Dict[str, Any]): return output def _encode_source(self, source: Dict[str, torch.Tensor]) -> torch.Tensor: - - # Extract and embed source tokens. + # Extract, embed and encode source tokens. source_embeddings = embedded_dropout( embed=self._token_embedder, words=source, dropout=self._dropoute if self.training else 0) source_embeddings = self._locked_dropout(source_embeddings, self._dropouti) + encoded_raw = self._rnn(source_embeddings) + encoded = self._locked_dropout(encoded_raw) - # Encode. - current_input = source_embeddings - hidden_states = [] - for layer, rnn in enumerate(self.rnns): - # Retrieve previous hidden state for layer. - if self._state is not None: - prev_hidden = self._state['layer_%i' % layer] - else: - prev_hidden = None - # Forward-pass. - output, hidden = rnn(current_input, prev_hidden) - output = output.contiguous() - # Update hidden state for layer. - hidden = tuple(h.detach() for h in hidden) - hidden_states.append(hidden) - # Apply dropout. - if layer == self._num_layers - 1: - dropped_output = self._locked_dropout(output, self._dropout) - else: - dropped_output = self._locked_dropout(output, self._dropouth) - current_input = dropped_output - encoded = current_input - - alpha_loss = dropped_output.pow(2).mean() - beta_loss = (output[:, 1:] - output[:, :-1]).pow(2).mean() - - # Update state. - self._state = {'layer_%i' % i: h for i, h in enumerate(hidden_states)} + alpha_loss = encoded.pow(2).mean() + beta_loss = (encoded_raw[:, 1:] - encoded_raw[:, :-1]).pow(2).mean() return encoded, alpha_loss, beta_loss @@ -728,9 +683,11 @@ def _mention_type_loss(self, entity mention. 
""" logits = self._fc_mention_type(encoded) - mention_loss = sequence_cross_entropy_with_logits(logits, mention_type, mask, - average='token') - + mention_logp = F.log_softmax(logits, -1) + mention_loss = -mention_logp.gather(-1, mention_type.unsqueeze(-1)).squeeze() + mention_loss = mention_loss * mask.float() + # mention_loss = sequence_cross_entropy_with_logits(logits, mention_type, mask, + # average='token') # if not self.training: self._new_mention_f1(predictions=logits, @@ -740,7 +697,7 @@ def _mention_type_loss(self, gold_labels=mention_type, mask=mask) - return mention_loss + return mention_loss.sum(-1) def _new_entity_logits(self, encoded: torch.Tensor, @@ -778,18 +735,17 @@ def _new_entity_loss(self, log_probs = masked_log_softmax(logits, shortlist_mask) else: log_probs = F.log_softmax(logits, dim=-1) - target_log_probs = torch.gather(log_probs, -1, target_inds.unsqueeze(-1)).squeeze(-1) - target_log_probs = target_log_probs * target_mask.float() - # Also don't predict on non-mentions + target_loss = -log_probs.gather( -1, target_inds.unsqueeze(-1)).squeeze(-1) + target_loss = target_loss * target_mask.float() mentions = ~entity_ids.eq(0) - target_log_probs = target_log_probs * mentions.float() + target_loss = target_loss * mentions.float() # self._new_entity_accuracy(predictions=log_probs[mask], # gold_labels=target_inds[mask]) # self._new_entity_accuracy20(predictions=log_probs[mask], # gold_labels=target_inds[mask]) - return -target_log_probs.sum() / (target_mask.sum() + 1e-13) + return target_loss.sum(-1) # / (target_mask.sum(-1).float() + 1e-13) def _parent_log_probs(self, encoded_head: torch.Tensor, @@ -902,7 +858,7 @@ def _knowledge_graph_entity_loss(self, self._parent_ppl(-torch.logsumexp(parent_log_probs, dim=-1)[mask].sum(), mask.float().sum()) self._relation_ppl(-torch.logsumexp(relation_log_probs, dim=-1)[mask].sum(), mask.float().sum()) # Lastly return the tokenwise average loss - return -target_log_probs.sum() / (target_mask.sum() + 1e-13) + return -target_log_probs.sum(-1) # / (target_mask.sum(-1) + 1e-13) def _generate_scores(self, encoded: torch.Tensor, @@ -1014,13 +970,13 @@ def _vocab_loss(self, flattened_mask = flattened_mask.squeeze() # Zero out padding loss combined_log_probs_extended_vocab = combined_log_probs_extended_vocab * flattened_mask.float() - vocab_loss = -combined_log_probs_extended_vocab.sum() / (mask.sum() + 1e-13) + vocab_loss = -combined_log_probs_extended_vocab.view(batch_size, sequence_length).sum(-1)# / (mask.sum(-1) + 1e-13) # Unknown penalty - only applies to non-copied unks true_unks = unks.squeeze() & ~copied.squeeze() & flattened_mask penalized_log_probs = combined_log_probs_extended_vocab - self._unk_penalty * true_unks.float() penalized_log_probs[~flattened_mask] = 0 - penalized_vocab_loss = -penalized_log_probs.sum() / (mask.sum() + 1e-13) + penalized_vocab_loss = -penalized_log_probs.view(batch_size, sequence_length).sum(-1)# / (mask.sum(-1) + 1e-13) # PERPLEXITY ### # Our perplexity terms are computed using the log probs computed w.r.t the source @@ -1061,13 +1017,13 @@ def train(self, mode=True): # batch sizes (e.g. the `reset` tensor will not be the right size). In future # implementations this should be handled more robustly. super().train(mode) - self._state = None + self._rnn.reset() @overrides def eval(self): # TODO: See train. 
super().eval() - self._state = None + self._rnn.reset() def get_metrics(self, reset: bool = False) -> Dict[str, float]: out = { @@ -1095,3 +1051,6 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]: out['relation_ppl'] = self._relation_ppl.get_metric(reset) return out + def reset_states(self, reset): + self._rnn.reset(reset) + self._recent_entities.reset(reset) diff --git a/kglm/models/kglm_disc.py b/kglm/models/kglm_disc.py index f714cc6..da9c0f8 100644 --- a/kglm/models/kglm_disc.py +++ b/kglm/models/kglm_disc.py @@ -1,6 +1,8 @@ import logging +from copy import deepcopy +from collections import namedtuple import math -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from allennlp.data.vocabulary import Vocabulary, DEFAULT_OOV_TOKEN from allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder @@ -9,12 +11,13 @@ from allennlp.nn.util import (get_text_field_mask, masked_log_softmax, masked_softmax, sequence_cross_entropy_with_logits) from allennlp.training.metrics import Average, CategoricalAccuracy, F1Measure, SequenceAccuracy +import numpy as np from overrides import overrides import torch import torch.nn.functional as F from kglm.data import AliasDatabase -from kglm.modules import (embedded_dropout, LockedDropout, WeightDrop, KnowledgeGraphLookup, +from kglm.modules import (embedded_dropout, LockedDropout, WeightDroppedLstm, KnowledgeGraphLookup, RecentEntities) from kglm.nn.util import nested_enumerate, parallel_sample from kglm.training.metrics import Ppl @@ -22,6 +25,14 @@ logger = logging.getLogger(__name__) +# Decoding from the KGLM discriminator requires ensuring that: +# * New mentions cannot be of recently mentioned entities. +# * Related mentions must be related to a recently mentioned entity. +# * Ongoing mention cannot continue non-mentions. +# The following structure tracks this information when performing beam search. +KglmBeamState = namedtuple('KglmBeamState', ['recent_entities', 'ongoing']) + + @Model.register('kglm-disc') class KglmDisc(Model): """ @@ -83,20 +94,12 @@ def __init__(self, token_embedding_dim = token_embedder.get_output_dim() self.entity_embedding_dim = entity_embedding_dim self.token_embedding_dim = token_embedding_dim - - rnns: List[torch.nn.Module] = [] - for i in range(num_layers): - if i == 0: - input_size = token_embedding_dim - else: - input_size = hidden_size - if (i == num_layers - 1): - output_size = token_embedding_dim + 2 * entity_embedding_dim - else: - output_size = hidden_size - rnns.append(torch.nn.LSTM(input_size, output_size, batch_first=True)) - rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in rnns] - self.rnns = torch.nn.ModuleList(rnns) + rnn_output_dim = token_embedding_dim + 2 * entity_embedding_dim + self._rnn = WeightDroppedLstm(num_layers=num_layers, + input_embedding_dim=self.token_embedding_dim, + hidden_size=self._hidden_size, + output_embedding_dim=rnn_output_dim, + dropout=self._wdrop) # Various linear transformations. 
self._fc_mention_type = torch.nn.Linear( @@ -111,8 +114,6 @@ def __init__(self, if tie_weights: self._fc_new_entity.weight = self._entity_embedder.weight - self._state: Optional[Dict[str, Any]] = None - # Metrics self._unk_index = vocab.get_token_index(DEFAULT_OOV_TOKEN) self._unk_penalty = math.log(vocab.get_vocab_size('tokens_unk')) @@ -128,6 +129,44 @@ def __init__(self, initializer(self) + @overrides + def forward(self, # pylint: disable=arguments-differ + source: Dict[str, torch.Tensor], + reset: torch.Tensor, + metadata: List[Dict[str, Any]], + mention_type: torch.Tensor = None, + raw_entity_ids: Dict[str, torch.Tensor] = None, + entity_ids: Dict[str, torch.Tensor] = None, + parent_ids: Dict[str, torch.Tensor] = None, + relations: Dict[str, torch.Tensor] = None, + shortlist: Dict[str, torch.Tensor] = None, + shortlist_inds: torch.Tensor = None, **kwargs) -> Dict[str, torch.Tensor]: + + # Tensorize the alias_database - this will only perform the operation once. + alias_database = metadata[0]['alias_database'] + alias_database.tensorize(vocab=self.vocab) + + # Reset the model if needed + self.reset_states(reset) + + if entity_ids is not None: + output_dict = self._forward_loop( + source=source, + alias_database=alias_database, + mention_type=mention_type, + raw_entity_ids=raw_entity_ids, + entity_ids=entity_ids, + parent_ids=parent_ids, + relations=relations, + shortlist=shortlist, + shortlist_inds=shortlist_inds) + else: + # TODO: Figure out what we want here - probably to do some kind of inference on + # entities / mention types. + output_dict = {} + + return output_dict + def sample(self, source: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor], @@ -135,6 +174,8 @@ def sample(self, metadata: Dict[str, Any], alias_copy_inds: torch.Tensor, shortlist: Dict[str, torch.Tensor] = None, + temperature: float = 1.0, + offset: bool = False, **kwargs) -> Dict[str, Any]: # **kwargs intended to eat the other fields if they are provided. """ Sampling annotations for the generative model. Note that unlike forward, this function @@ -146,35 +187,34 @@ def sample(self, alias_database.tensorize(vocab=self.vocab) # Reset the model if needed - if reset.any() and (self._state is not None): - for layer in range(self._num_layers): - h, c = self._state['layer_%i' % layer] - h[:, reset, :] = torch.zeros_like(h[:, reset, :]) - c[:, reset, :] = torch.zeros_like(c[:, reset, :]) - self._state['layer_%i' % layer] = (h, c) - self._recent_entities.reset(reset) - - logp = 0.0 + self.reset_states(reset) mask = get_text_field_mask(target).byte() + batch_size = mask.shape[0] # We encode the target tokens (**not** source) since the discriminative model makes # predictions on the current token, but the generative model expects labels for the # **next** (e.g. target) token!
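+ # When ``offset`` is set, the source tokens are encoded instead of the target tokens, shifting the encoder outputs (and hence the sampled annotations) by one position.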
- encoded, *_ = self._encode_source(target['tokens']) + if not offset: + encoded, *_ = self._encode_source(target['tokens']) + else: + encoded, *_ = self._encode_source(source['tokens']) splits = [self.token_embedding_dim] + [self.entity_embedding_dim] * 2 encoded_token, encoded_head, encoded_relation = encoded.split(splits, dim=-1) + # logp = 0.0 + logp = encoded.new_zeros(batch_size) + # Compute new mention logits - mention_logits = self._fc_mention_type(encoded_token) + mention_logits = self._fc_mention_type(encoded_token) / temperature mention_probs = F.softmax(mention_logits, dim=-1) mention_type = parallel_sample(mention_probs) - mention_logp = mention_probs.gather(-1, mention_type.unsqueeze(-1)).log() - mention_logp[~mask] = 0 - mention_logp = mention_logp.sum() + _mention_logp = mention_probs.gather(-1, mention_type.unsqueeze(-1)).log() + _mention_logp[~mask] = 0 + mention_logp = _mention_logp.view(batch_size, -1).sum(-1) # Compute entity logits new_entity_mask = mention_type.eq(1) - new_entity_logits = self._new_entity_logits(encoded_head + encoded_relation, shortlist) + new_entity_logits = self._new_entity_logits(encoded_head + encoded_relation, shortlist) / temperature if self._use_shortlist: # If using shortlist, then samples are indexed w.r.t the shortlist and entity_ids must be looked up shortlist_mask = get_text_field_mask(shortlist) @@ -188,7 +228,7 @@ def sample(self, _new_entity_logp = new_entity_probs.gather(-1, shortlist_inds.unsqueeze(-1)).log() new_entity_samples = shortlist['entity_ids'].gather(1, shortlist_inds) else: - new_entity_logits = new_entity_logits + new_entity_logits[:,:,:4] = -1e32 # A new entity mustn't be padding, unknown, or a literal # If not using shortlist, then samples are indexed w.r.t to the global vocab new_entity_probs = F.softmax(new_entity_logits, dim=-1) new_entity_samples = parallel_sample(new_entity_probs) @@ -197,7 +237,7 @@ def sample(self, # Zero out masked tokens and non-new entity predictions _new_entity_logp[~mask] = 0 _new_entity_logp[~new_entity_mask] = 0 - new_entity_logp = _new_entity_logp.sum() + new_entity_logp = _new_entity_logp.view(batch_size, -1).sum(-1) # Start filling in the entity ids entity_ids = torch.zeros_like(target['tokens']) @@ -213,7 +253,7 @@ def sample(self, # Derived mentions need to be computed sequentially. parent_ids = torch.zeros_like(target['tokens']).unsqueeze(-1) derived_entity_mask = mention_type.eq(2) - derived_entity_logp = 0.0 + derived_entity_logp = torch.zeros_like(new_entity_logp) sequence_length = target['tokens'].shape[1] for i in range(sequence_length): @@ -235,7 +275,7 @@ def sample(self, # Compute logits w.r.t **current** hidden state only current_head_encoding = encoded_head[:, i].unsqueeze(1) - selection_logits = torch.bmm(current_head_encoding, candidate_embeddings.transpose(1, 2)) + selection_logits = torch.bmm(current_head_encoding, candidate_embeddings.transpose(1, 2)) / temperature selection_probs = masked_softmax(selection_logits, candidate_mask) # Only sample if the is at least one viable candidate (e.g. if a sampling distribution @@ -254,7 +294,7 @@ def sample(self, parent_logp[viable_candidate_mask] = viable_logp.squeeze(-1) parent_ids[current_mask, i] = _parent_ids[current_mask] # TODO: Double-check - derived_entity_logp += parent_logp[current_mask].sum() + derived_entity_logp[current_mask] += parent_logp[current_mask].squeeze(-1) ## SAMPLE RELATIONS ## @@ -270,7 +310,7 @@ def sample(self, # Compute the score for each relation w.r.t the current encoding. 
NOTE: In the loss # code index has a slice. We don't need that here since there is always a # **single** parent. - logits = torch.mv(relation_embedding, current_relation_encoding[index]) + logits = torch.mv(relation_embedding, current_relation_encoding[index]) / temperature # Convert to probability tail_probs = F.softmax(logits, dim=-1) # Sample @@ -278,7 +318,7 @@ def sample(self, # Get logp. Ignoring the current_mask here is **super** dodgy, but since we forced # null parents to zero we shouldn't be accumulating probabilities for unused predictions. tail_logp = tail_probs.gather(-1, tail_sample).log() - derived_entity_logp += tail_logp.sum() # Sum is redundant, just need it to make logp a scalar + derived_entity_logp[index[:-1]] += tail_logp.sum() # Sum is redundant, just need it to make logp a scalar # Map back to raw id raw_tail_id = tail_id_lookup[tail_sample] @@ -329,90 +369,46 @@ def sample(self, logp = mention_logp + new_entity_logp + derived_entity_logp return {'sample': sample, 'logp': logp} - @overrides - def forward(self, # pylint: disable=arguments-differ - source: Dict[str, torch.Tensor], - reset: torch.Tensor, - metadata: List[Dict[str, Any]], - mention_type: torch.Tensor = None, - raw_entity_ids: Dict[str, torch.Tensor] = None, - entity_ids: Dict[str, torch.Tensor] = None, - parent_ids: Dict[str, torch.Tensor] = None, - relations: Dict[str, torch.Tensor] = None, - shortlist: Dict[str, torch.Tensor] = None, - shortlist_inds: torch.Tensor = None) -> Dict[str, torch.Tensor]: - - # Tensorize the alias_database - this will only perform the operation once. - alias_database = metadata[0]['alias_database'] - alias_database.tensorize(vocab=self.vocab) - - # Reset the model if needed - if reset.any() and (self._state is not None): - for layer in range(self._num_layers): - h, c = self._state['layer_%i' % layer] - h[:, reset, :] = torch.zeros_like(h[:, reset, :]) - c[:, reset, :] = torch.zeros_like(c[:, reset, :]) - self._state['layer_%i' % layer] = (h, c) - self._recent_entities.reset(reset) - - if entity_ids is not None: - output_dict = self._forward_loop( - source=source, - alias_database=alias_database, - mention_type=mention_type, - raw_entity_ids=raw_entity_ids, - entity_ids=entity_ids, - parent_ids=parent_ids, - relations=relations, - shortlist=shortlist, - shortlist_inds=shortlist_inds) - else: - # TODO: Figure out what we want here - probably to do some king of inference on - # entities / mention types. - output_dict = {} + def get_raw_entity_ids(self, entity_ids: torch.LongTensor) -> torch.LongTensor: + raw_entity_ids = torch.zeros_like(entity_ids) + for *index, entity_id in nested_enumerate(entity_ids.tolist()): + token = self.vocab.get_token_from_index(entity_id, 'entity_ids') + raw_entity_id = self.vocab.get_token_index(token, 'raw_entity_ids') + raw_entity_ids[tuple(index)] = raw_entity_id + return raw_entity_ids - return output_dict + def get_entity_ids(self, raw_entity_ids: torch.LongTensor) -> torch.LongTensor: + entity_ids = torch.zeros_like(raw_entity_ids) + for *index, raw_entity_id in nested_enumerate(raw_entity_ids.tolist()): + token = self.vocab.get_token_from_index(raw_entity_id, 'raw_entity_ids') + entity_id = self.vocab.get_token_index(token, 'entity_ids') + entity_ids[tuple(index)] = entity_id + return entity_ids def _encode_source(self, source: Dict[str, torch.Tensor]) -> torch.Tensor: - # Extract and embed source tokens. + # Extract, embed and encode source tokens. 
source_embeddings = embedded_dropout( embed=self._token_embedder, words=source, dropout=self._dropoute if self.training else 0) source_embeddings = self._locked_dropout(source_embeddings, self._dropouti) + encoded_raw = self._rnn(source_embeddings) + encoded = self._locked_dropout(encoded_raw) - # Encode. - current_input = source_embeddings - hidden_states = [] - for layer, rnn in enumerate(self.rnns): - # Retrieve previous hidden state for layer. - if self._state is not None: - prev_hidden = self._state['layer_%i' % layer] - else: - prev_hidden = None - # Forward-pass. - output, hidden = rnn(current_input, prev_hidden) - output = output.contiguous() - # Update hidden state for layer. - hidden = tuple(h.detach() for h in hidden) - hidden_states.append(hidden) - # Apply dropout. - if layer == self._num_layers - 1: - dropped_output = self._locked_dropout(output, self._dropout) - else: - dropped_output = self._locked_dropout(output, self._dropouth) - current_input = dropped_output - encoded = current_input - - alpha_loss = dropped_output.pow(2).mean() - beta_loss = (output[:, 1:] - output[:, :-1]).pow(2).mean() - - # Update state. - self._state = {'layer_%i' % i: h for i, h in enumerate(hidden_states)} + alpha_loss = encoded.pow(2).mean() + beta_loss = (encoded_raw[:, 1:] - encoded_raw[:, :-1]).pow(2).mean() return encoded, alpha_loss, beta_loss def _mention_type_loss(self, encoded: torch.Tensor, mention_type: torch.Tensor, @@ -422,8 +418,12 @@ def _mention_type_loss(self, entity mention. """ logits = self._fc_mention_type(encoded) - mention_type_loss = sequence_cross_entropy_with_logits(logits, mention_type, mask, - average='token') + mention_logp = F.log_softmax(logits, -1) + mention_loss = -mention_logp.gather(-1, mention_type.unsqueeze(-1)).squeeze() + mention_loss = mention_loss * mask.float() + # mention_loss = sequence_cross_entropy_with_logits(logits, mention_type, mask, + # average='token') + # if not self.training: self._new_mention_f1(predictions=logits, gold_labels=mention_type, @@ -432,11 +432,11 @@ def _mention_type_loss(self, gold_labels=mention_type, mask=mask) - return mention_type_loss + return mention_loss.sum(-1) def _new_entity_logits(self, encoded: torch.Tensor, - shortlist: torch.Tensor) -> torch.Tensor: + shortlist: torch.Tensor = None) -> torch.Tensor: if self._use_shortlist: # Embed the shortlist entries shortlist_embeddings = embedded_dropout( @@ -468,23 +468,17 @@ def _new_entity_loss(self, shortlist_mask = get_text_field_mask(shortlist) log_probs = masked_log_softmax(logits, shortlist_mask) else: - logits = logits log_probs = F.log_softmax(logits, dim=-1) - num_categories = log_probs.shape[-1] - log_probs = log_probs.view(-1, num_categories) - target_inds = target_inds.view(-1) - target_log_probs = torch.gather(log_probs, -1, target_inds.unsqueeze(-1)).squeeze(-1) - - mask = ~target_inds.eq(0) - target_log_probs[~mask] = 0 + loss = -log_probs.gather(-1, target_inds.unsqueeze(-1)).squeeze(-1) + loss = loss * target_mask.float() - if mask.any(): - self._new_entity_accuracy(predictions=log_probs[mask], - gold_labels=target_inds[mask]) - self._new_entity_accuracy20(predictions=log_probs[mask], - gold_labels=target_inds[mask]) + # if target_mask.any(): + # self._new_entity_accuracy(predictions=log_probs[target_mask], + #
gold_labels=target_inds[target_mask]) + # self._new_entity_accuracy20(predictions=log_probs[target_mask], + # gold_labels=target_inds[target_mask]) - return -target_log_probs.sum() / (target_mask.sum() + 1e-13) + return loss.sum(-1) # / (target_mask.sum() + 1e-13) def _parent_log_probs(self, encoded_head: torch.Tensor, @@ -596,7 +590,7 @@ def _knowledge_graph_entity_loss(self, self._parent_ppl(-torch.logsumexp(parent_log_probs, dim=-1)[mask].sum(), mask.float().sum()) self._relation_ppl(-torch.logsumexp(relation_log_probs, dim=-1)[mask].sum(), mask.float().sum()) # Lastly return the tokenwise average loss - return -target_log_probs.sum() / (target_mask.sum() + 1e-13) + return -target_log_probs.sum(-1) # / (target_mask.sum() + 1e-13) def _forward_loop(self, source: Dict[str, torch.Tensor], @@ -629,7 +623,7 @@ def _forward_loop(self, # Predict whether or not the next token will be an entity mention, and if so which type. mention_type_loss = self._mention_type_loss(encoded_token, mention_type, target_mask) - self._avg_mention_type_loss(float(mention_type_loss)) + self._avg_mention_type_loss(float(mention_type_loss.sum()/target_mask.sum())) # For new mentions, predict which entity (among those in the supplied shortlist) will be # mentioned. @@ -643,8 +637,9 @@ def _forward_loop(self, entity_ids, None, target_mask) + logger.debug('new_entity_loss: %s', new_entity_loss) - self._avg_new_entity_loss(float(new_entity_loss)) + self._avg_new_entity_loss(float(new_entity_loss.sum()/target_mask.sum())) # For derived mentions, first predict which parent(s) to expand... knowledge_graph_entity_loss = self._knowledge_graph_entity_loss(encoded_head, @@ -653,10 +648,10 @@ def _forward_loop(self, entity_ids, parent_ids, target_mask) - self._avg_knowledge_graph_entity_loss(float(knowledge_graph_entity_loss)) + self._avg_knowledge_graph_entity_loss(float(knowledge_graph_entity_loss.sum()/target_mask.sum())) # Compute total loss - loss = mention_type_loss + new_entity_loss + knowledge_graph_entity_loss + loss = (mention_type_loss + new_entity_loss + knowledge_graph_entity_loss).sum() / target_mask.sum() # Activation regularization if self._alpha: @@ -667,20 +662,467 @@ def _forward_loop(self, return {'loss': loss} + def _next_mention_type_logp(self, next_mention_type_logits, beam_states): + """ + Computes log probabilities of mention type for next token, .e.g, adjusts logits to prevent ongoing non-mentions. + Intended for use when performing beam search. + + Parameters + ========== + next_mention_type_logits: torch.FloatTensor + Tensor of shape (batch_size, num_mention_types) containing next mention type logits. + beam_states: List[KglmBeamState] + List of previous beam states. + + Returns + ======= + next_mention_type_logp: + Tensor of shape (batch_size, beam_width, num_mention_types) containing next mention type log probabilities. + """ + beam_width = len(beam_states) + + # Tile the mention_logits, and apply penalty to non-ongoing mentions + out = next_mention_type_logits.unsqueeze(1).repeat(1, beam_width, 1) + for i, beam_state in enumerate(beam_states): + out[~beam_state.ongoing, i, -1] = -1e32 + + return F.log_softmax(out, dim=-1) + + def _next_new_entity_logp(self, next_new_entity_logits, beam_states): + """ + Computes log probabilities of new entity mentions. + Intended for use when performing beam search. + + Parameters + ========== + next_new_entity_logits: torch.FloatTensor + Tensor of shape (batch_size, num_entities) containing next new entity logits. 
+ beam_states: List[KglmBeamState] + List of previous beam states. + + Returns + ======= + next_new_entity_logp: + Tensor of shape (batch_size, beam_width, num_entities) containing next new entity log probabilities. + """ + beam_width = len(beam_states) + # Tile the new entity logits, and penalize entities that already appear in the recent entities list + out = next_new_entity_logits.unsqueeze(1).repeat(1, beam_width, 1) + for j, beam_state in enumerate(beam_states): + self._recent_entities.load_beam_state(beam_state.recent_entities) + for i, recent_ids in enumerate(self._recent_entities._remaining): + for recent_id in recent_ids: + out[i, j, recent_id] = -1e32 + return F.log_softmax(out, dim=-1) + + def _next_related_entity_logp(self, next_encoded_head, next_encoded_relation, beam_states): + """ + Computes log probabilities of related entity mentions. + Intended for use when performing beam search. + + Parameters + ========== + next_encoded_head: torch.FloatTensor + Tensor of shape (batch_size, embedding_dim) of the head encodings. + next_encoded_relation: torch.FloatTensor + Tensor of shape (batch_size, embedding_dim) of the relation encodings. + beam_states: List[KglmBeamState] + List of previous beam states. + + Returns + ======= + logp: + Tensor of shape (batch_size, beam_width, num_candidates) containing the log + probability of the parent/relation combination. + And a dictionary containing the annotation data. + parent_ids: + Tensor of shape (batch_size, beam_width, num_candidates) + relations: + Tensor of shape (batch_size, beam_width, num_candidates) + raw_entity_ids: + Tensor of shape (batch_size, beam_width, num_candidates) + """ + batch_size = next_encoded_head.size(0) + beam_width = len(beam_states) + logp_arr = np.empty((batch_size, beam_width), dtype=object) + parent_ids_arr = np.empty((batch_size, beam_width), dtype=object) + relations_arr = np.empty((batch_size, beam_width), dtype=object) + raw_entity_ids_arr = np.empty((batch_size, beam_width), dtype=object) + for j, beam_state in enumerate(beam_states): + # Get the set of candidate parents from the RecentEntities module. + # Since we are only considering candidates for a single timestep we can get the parents + # directly from the RecentEntities._remaining dictionaries' keys. + self._recent_entities.load_beam_state(beam_state.recent_entities) + for i, candidate_ids in enumerate(self._recent_entities._remaining): + # Cast candidate ids to a tensor, lookup embeddings, and compute score.
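+ # (The number of candidates differs across batch elements and beams, so the results are first collected in per-element object arrays and only padded into rectangular tensors at the end.)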
+ candidate_ids = torch.tensor(list(candidate_ids.keys()), + dtype=torch.int64, + device=next_encoded_head.device) + candidate_embeddings = self._entity_embedder(candidate_ids) + candidate_logits = torch.mv(candidate_embeddings, next_encoded_head[i]) + candidate_logp = F.log_softmax(candidate_logits, dim=-1) + + # Lookup relations + _, s, r, o = self._knowledge_graph_lookup(candidate_ids) + relation_embeddings_list = [self._relation_embedder(_r) for _r in r] + + # Stop early if node is isolated + if not s: + logp_arr[i, j] = torch.tensor([], dtype=torch.float32, device=next_encoded_head.device) + parent_ids_arr[i, j] = torch.tensor([], dtype=torch.int64, device=next_encoded_head.device) + relations_arr[i, j] = torch.tensor([], dtype=torch.int64, device=next_encoded_head.device) + raw_entity_ids_arr[i, j] = torch.tensor([], dtype=torch.int64, device=next_encoded_head.device) + continue + + # Otherwise compute relation probabilities for each parent and combine + temp_logp = [] + temp_parent_ids = [] + temp_relations = [] + temp_raw_entity_ids = [] + for idx, relation_embeddings in enumerate(relation_embeddings_list): + num_relations = relation_embeddings.size(0) + relation_logits = torch.mv(relation_embeddings, next_encoded_relation[i]) + relation_logp = F.log_softmax(relation_logits, dim=-1) + temp_logp.append(candidate_logp[idx] + relation_logp) + temp_parent_ids.append(s[idx].repeat(num_relations)) + temp_relations.append(r[idx]) + temp_raw_entity_ids.append(o[idx]) + logp_arr[i, j] = torch.cat(temp_logp) + parent_ids_arr[i, j] = torch.cat(temp_parent_ids) + relations_arr[i, j] = torch.cat(temp_relations) + raw_entity_ids_arr[i, j] = torch.cat(temp_raw_entity_ids) + + num_candidates = max(t.size(0) for t in logp_arr.flatten()) + logp = next_encoded_head.new_full((batch_size, beam_width, num_candidates), -1e32) + parent_ids = next_encoded_head.new_zeros((batch_size, beam_width, num_candidates), dtype=torch.int64) + relations = next_encoded_head.new_zeros((batch_size, beam_width, num_candidates), dtype=torch.int64) + raw_entity_ids = next_encoded_head.new_zeros((batch_size, beam_width, num_candidates), dtype=torch.int64) + for i in range(batch_size): + for j in range(beam_width): + size = logp_arr[i][j].size(0) + logp[i, j, :size] = logp_arr[i][j] + parent_ids[i, j, :size] = parent_ids_arr[i][j] + relations[i, j, :size] = relations_arr[i][j] + raw_entity_ids[i, j, :size] = raw_entity_ids_arr[i][j] + + annotations = { + 'parent_ids': parent_ids, + 'relations': relations, + 'raw_entity_ids': raw_entity_ids + } + + return logp, annotations + + def _top_k_annotations(self, + next_mention_type_logp, + next_new_entity_logp, + next_related_entity_logp, + related_entity_annotations, + output, + k): + """ + Aggregate log probabilities and return top-k results. + + Don't be intimidated by the amount of code - almost all of it relates to various + bookkeeping tasks to get the annotations.
+ """ + # === Bookkeeping ==== + # Need to get all of the relevant sizes + batch_size, beam_width, n_new = next_new_entity_logp.size() + n_related = next_related_entity_logp.size(-1) + + # Derive the length of the full tensor: # new + # related + ongoing + unrelated + length = n_new + n_related + 2 + total_logp = next_mention_type_logp.new_empty(batch_size, beam_width, length) + + # For clarity, name the slices + new_slice = slice(0, n_new) + related_slice = slice(n_new, n_new + n_related) + ongoing_slice = -2 + null_slice = -1 + + # === Annotation lookups === + mention_type_lookup = torch.zeros_like(total_logp, dtype=torch.int64) + parent_id_lookup = torch.zeros_like(total_logp, dtype=torch.int64) + relation_lookup = torch.zeros_like(total_logp, dtype=torch.int64) + raw_entity_id_lookup = torch.zeros_like(total_logp, dtype=torch.int64) + entity_id_lookup = torch.zeros_like(total_logp, dtype=torch.int64) + + # Mention type + mention_type_lookup[:, :, new_slice] = 1 + mention_type_lookup[:, :, related_slice] = 2 + mention_type_lookup[:, :, ongoing_slice] = 3 + mention_type_lookup[:, :, null_slice] = 0 + + # New + id_range = torch.arange(n_new, device=entity_id_lookup.device).view(1, 1, n_new) + entity_id_lookup[:, :, new_slice] = id_range + raw_entity_id_lookup[:, :, new_slice] = self.get_raw_entity_ids(id_range) + + # Related + parent_id_lookup[:, :, related_slice] = related_entity_annotations['parent_ids'] + relation_lookup[:, :, related_slice] = related_entity_annotations['relations'] + raw_entity_id_lookup[:, :, related_slice] = related_entity_annotations['raw_entity_ids'] + entity_id_lookup[:, :, related_slice] = self.get_entity_ids(related_entity_annotations['raw_entity_ids']) + + # Ongoing + if output is not None: + parent_id_lookup[:, :, ongoing_slice] = output['parent_ids'] + relation_lookup[:, :, ongoing_slice] = output['relations'] + entity_id_lookup[:, :, ongoing_slice] = output['entity_ids'] + raw_entity_id_lookup[:, :, ongoing_slice] = output['raw_entity_ids'] + + # === Logp === + + # Set the mention probabilities + total_logp[:, :, new_slice] = next_mention_type_logp[:, :, 1].unsqueeze(-1) + total_logp[:, :, related_slice] = next_mention_type_logp[:, :, 2].unsqueeze(-1) + total_logp[:, :, ongoing_slice] = next_mention_type_logp[:, :, 3] + total_logp[:, :, null_slice] = next_mention_type_logp[:, :, 0] + + # Add the entity probabilities + total_logp[:, :, new_slice] += next_new_entity_logp + total_logp[:, :, related_slice] += next_related_entity_logp + + # If available add the previous beam probabilities + if output is not None: + total_logp += output['logp'].unsqueeze(-1) + + # Get the top-k outputs + top_logp, top_indices = total_logp.view(batch_size, -1).topk(k, dim=-1) + output = { + 'logp': top_logp, + 'backpointers': top_indices // length, + 'mention_types': mention_type_lookup.view(batch_size, -1).gather(-1, top_indices), + 'parent_ids': parent_id_lookup.view(batch_size, -1).gather(-1, top_indices), + 'relations': relation_lookup.view(batch_size, -1).gather(-1, top_indices), + 'entity_ids': entity_id_lookup.view(batch_size, -1).gather(-1, top_indices), + 'raw_entity_ids': raw_entity_id_lookup.view(batch_size, -1).gather(-1, top_indices) + } + return output + + def _update_beam_states(self, output, beam_states): + """ + Ensure that the correct recent entities modules and ongoing flags are properly taken from + the last step and updated using the current predicted outputs. 
+ """ + new_beam_states = [] + backpointers = output['backpointers'] + batch_size, beam_width = backpointers.size() + # To facilitate indexing with the backpointers, we'll store the RecentEntities' _remaining + # dicts in a numpy array. + remaining_dicts = np.empty((batch_size, len(beam_states)), dtype=object) + for j, beam_state in enumerate(beam_states): + self._recent_entities.load_beam_state(beam_state.recent_entities) + for i in range(batch_size): + remaining_dicts[i, j] = self._recent_entities._remaining[i] + + for i in range(beam_width): + # Everything but null mention types can be ongoing in next step. + ongoing = output['mention_types'][:, i] != 0 + + # Trace backpointers to retrieve correct recent entities dicts, and update using the + # current output. + bp = backpointers[:, i].cpu().numpy() + remaining = remaining_dicts[np.arange(batch_size), bp].tolist() + self._recent_entities.load_beam_state({'remaining': remaining}) + self._recent_entities(output['entity_ids'][:, i].unsqueeze(-1)) + + # Add beam states + new_beam_states.append( + KglmBeamState(recent_entities=self._recent_entities.beam_state(), + ongoing=ongoing) + ) + + return new_beam_states + + def _to_raw_entity_tokens(self, x): + """ + Returns the raw entity id strings for a nested list of raw entity ids + """ + if isinstance(x, list): + return [self._to_raw_entity_tokens(i) for i in x] + elif isinstance(x, int): + return self.vocab.get_token_from_index(x, 'raw_entity_ids') + else: + return ValueError('Expecting a nested list of raw entity ids') + + def _trace_backpointers(self, + source, + target, + reset, + metadata, + k, + predictions): + """ + Traces backpointers to collect the top-k annotations. + """ + batch_size, seq_length = source['tokens'].shape + alias_database = metadata[0]['alias_database'] + + new_source = {key: value.unsqueeze(1).repeat(1, k, 1).view(batch_size * k, -1) for key, value in source.items()} + new_target = {key: value.unsqueeze(1).repeat(1, k, 1).view(batch_size * k, -1) for key, value in target.items()} + new_reset = reset.unsqueeze(1).repeat(1, k).view(batch_size * k) + new_metadata = [metadata[i] for i in range(batch_size) for _ in range(k)] + + mention_types = [] + parent_ids = [] + relations = [] + raw_entity_ids = [] + entity_ids = [] + + backpointer = None + + for prediction in reversed(predictions): + if backpointer is None: + mention_types.append(prediction['mention_types']) + parent_ids.append(prediction['parent_ids']) + relations.append(prediction['relations']) + raw_entity_ids.append(prediction['raw_entity_ids']) + entity_ids.append(prediction['entity_ids']) + else: + mention_types.append(prediction['mention_types'].gather(1, backpointer)) + parent_ids.append(prediction['parent_ids'].gather(1, backpointer)) + relations.append(prediction['relations'].gather(1, backpointer)) + raw_entity_ids.append(prediction['raw_entity_ids'].gather(1, backpointer)) + entity_ids.append(prediction['entity_ids'].gather(1, backpointer)) + if backpointer is None: + backpointer = prediction['backpointers'] + else: + backpointer = prediction['backpointers'].gather(1, backpointer) + + mention_types = torch.stack(mention_types[::-1], dim=-1).view(batch_size * k, -1) + parent_ids = torch.stack(parent_ids[::-1], dim=-1).view(batch_size * k, -1) + relations = torch.stack(relations[::-1], dim=-1).view(batch_size * k, -1) + raw_entity_ids = torch.stack(raw_entity_ids[::-1], dim=-1).view(batch_size * k, -1) + entity_ids = torch.stack(entity_ids[::-1], dim=-1).view(batch_size * k, -1) + + # One final bit of 
complexity - we need to get copy indices. + raw_entity_tokens = self._to_raw_entity_tokens(raw_entity_ids.tolist()) + target_tokens = [x['target_tokens'] for x in new_metadata] + alias_copy_inds_list = alias_database.nested_token_to_uid(raw_entity_tokens, target_tokens) + alias_copy_inds = torch.tensor(alias_copy_inds_list, device=mention_types.device) + + return { + 'source': new_source, + 'target': new_target, + 'reset': new_reset, + 'metadata': new_metadata, + 'mention_types': mention_types, + 'parent_ids': parent_ids, + 'relations': relations, + 'raw_entity_ids': raw_entity_ids, + 'entity_ids': entity_ids, + 'alias_copy_inds': alias_copy_inds + } + + def beam_search(self, + source: Dict[str, torch.Tensor], + target: Dict[str, torch.Tensor], + reset: torch.ByteTensor, + metadata: Dict[str, Any], + k: int, + **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Obtain the top-k (approximately) most likely predictions from the model using beam + search. Unlike typical beam search, all of the beam states are returned instead of just + the most likely. + + The returned candidates are intended to be marginalized over to obtain an upper bound for + the token-level perplexity of the KGLM. + + Parameters + ========== + source : ``Dict[str, torch.Tensor]`` + A tensor of shape ``(batch_size, sequence_length)`` containing the sequence of + tokens. + reset : ``torch.ByteTensor`` + Whether or not to reset the model's state. This should be done at the start of each + new sequence. + metadata : ``Dict[str, Any]`` + Assorted metadata. Should contain the alias database, as well as the token strings (needed to retrieve copy indices). + k : ``int`` + Number of predictions to return. + + Returns + ======= + predictions : ``torch.Tensor`` + A tensor of shape ``(batch_size * k, sequence_length)`` containing the top-k + predictions. + logp : ``torch.Tensor`` + The log-probabilities of each prediction. WARNING: These are returned purely for + diagnostic purposes and should not be factored into the perplexity calculation. + """ + # We want the output fields to be properly aligned for the generative model, which makes + # predictions for the **target** tokens! Hence, we feed them as the input (instead of the + # source tokens). + batch_size, sequence_length = target['tokens'].shape + + # Reset the model's internal state. + if not reset.all(): + raise RuntimeError('Detecting that not all states are being `reset` (e.g., that input ' + 'sequences have been split). Cannot predict top-K annotations in ' + 'this setting!') + self.reset_states(reset) + + # The following tensors can be computed using only the encoder: + # * The 3-headed encodings. + # * The (unconstrained) mention type logits. + # * The (unconstrained) new entity logits. + # Although we can compute the mention type and new entity logits, we will need to compute + # the log-probabilities during decoding due to the following constraints: + # * `mention_type` = CONTINUE only if the previous token type was a new or ongoing mention. + # * `new_entity` cannot be in recent entities.
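+ # The related-entity log-probabilities, by contrast, depend on each beam's evolving recent-entities state, so they have to be recomputed inside the decoding loop below.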
+ encoded, *_ = self._encode_source(target['tokens']) + splits = [self.token_embedding_dim] + [self.entity_embedding_dim] * 2 + encoded_token, encoded_head, encoded_relation = encoded.split(splits, dim=-1) + mention_type_logits = self._fc_mention_type(encoded_token) + new_entity_logits = self._new_entity_logits(encoded_head + encoded_relation) + + # Beam search logic + predictions: List[Dict[str, torch.Tensor]] = [] + beam_states = [KglmBeamState(recent_entities=self._recent_entities.beam_state(), + ongoing=torch.zeros_like(reset))] + output = None + + for timestep in range(sequence_length): + logger.debug(timestep) + # Get log probabilities of all next states + next_mention_type_logp = self._next_mention_type_logp(mention_type_logits[:, timestep], + beam_states) + next_new_entity_logp = self._next_new_entity_logp(new_entity_logits[:, timestep], + beam_states) + next_related_entity_logp, related_entity_annotations = self._next_related_entity_logp( + encoded_head[:, timestep], + encoded_relation[:, timestep], + beam_states) + + output = self._top_k_annotations(next_mention_type_logp, + next_new_entity_logp, + next_related_entity_logp, + related_entity_annotations, + output, + k) + beam_states = self._update_beam_states(output, beam_states) + predictions.append(output) + + annotation = self._trace_backpointers(source, target, reset, metadata, k, predictions) + + return annotation + @overrides def train(self, mode=True): - # TODO: This is a temporary hack to ensure that the internal state resets when the model - # switches from training to evaluation. The complication arises from potentially differing - # batch sizes (e.g. the `reset` tensor will not be the right size). In future - # implementations this should be handled more robustly. + # This is a hack to ensure that the internal state resets when the model switches from + # training to evaluation. The complication arises from potentially differing batch sizes + # (e.g. the `reset` tensor will not be the right size). super().train(mode) - self._state = None + self._rnn.reset() @overrides def eval(self): - # TODO: See train. + # See train. 
super().eval() - self._state = None + self._rnn.reset() def get_metrics(self, reset: bool = False) -> Dict[str, float]: out = { @@ -702,3 +1144,7 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]: out['parent_ppl'] = self._parent_ppl.get_metric(reset) out['relation_ppl'] = self._relation_ppl.get_metric(reset) return out + + def reset_states(self, reset): + self._rnn.reset(reset) + self._recent_entities.reset(reset) diff --git a/kglm/modules/__init__.py b/kglm/modules/__init__.py index 13ab4d2..9aa110a 100644 --- a/kglm/modules/__init__.py +++ b/kglm/modules/__init__.py @@ -4,4 +4,4 @@ from .locked_dropout import LockedDropout from .recent_entities import RecentEntities from .splitcross import SplitCrossEntropyLoss -from .weight_drop import WeightDrop +from .weight_drop import WeightDrop, WeightDroppedLstm diff --git a/kglm/modules/dynamic_embeddings.py b/kglm/modules/dynamic_embeddings.py index 3578fd3..d2ac131 100644 --- a/kglm/modules/dynamic_embeddings.py +++ b/kglm/modules/dynamic_embeddings.py @@ -1,5 +1,6 @@ -from typing import Dict, Optional +from copy import deepcopy import logging +from typing import Dict, Optional from overrides import overrides import torch @@ -30,14 +31,16 @@ class DynamicEmbedding(Module): """ def __init__(self, embedding_dim: int, - max_embeddings: int) -> None: + max_embeddings: int, + tied_weight: Optional[torch.nn.Parameter]) -> None: super(DynamicEmbedding, self).__init__() self._embedding_dim = embedding_dim self._max_embeddings = max_embeddings - self._initial_embedding = Parameter(F.normalize(torch.randn(embedding_dim), dim=0)) + self._initial_embedding = tied_weight + # self._initial_embedding = Parameter(F.normalize(torch.randn(embedding_dim), dim=0)) - self._distance_scalar = Parameter(torch.tensor(1e-6)) # pylint: disable=E1102 + self._distance_scalar = Parameter(torch.tensor(1e-3)) # pylint: disable=E1102 self._embedding_projection = torch.nn.Linear(in_features=embedding_dim, out_features=embedding_dim, bias=False) @@ -68,10 +71,10 @@ def reset_states(self, reset: torch.ByteTensor) -> None: # This simplifies the case where the batch_size has been if reset.all(): self.embeddings = self._initial_embedding.new_zeros(batch_size, self._max_embeddings, - self._embedding_dim) + self._embedding_dim) self.num_embeddings = self._initial_embedding.new_zeros(batch_size, dtype=torch.int64) self.last_seen = self._initial_embedding.new_zeros(batch_size, self._max_embeddings, - dtype=torch.int64) + dtype=torch.int64) else: self.embeddings[reset] = 0 self.num_embeddings[reset] = 0 @@ -87,7 +90,7 @@ def detach_states(self) -> None: self.embeddings = self.embeddings.detach() def add_embeddings(self, - timestep: int, + timestep: torch.LongTensor, mask: Optional[torch.Tensor] = None) -> None: """ Adds new embeddings to the current collection of embeddings. @@ -109,7 +112,7 @@ def add_embeddings(self, # Embeddings are initialized by adding a small amount of random noise to the initial # embedding tensor then normalizing. 
- initial = self._initial_embedding.repeat((mask.sum(), 1, 1)) + initial = self._initial_embedding[1].repeat((mask.sum(), 1, 1)) noise = 1e-4 * torch.randn_like(initial) # 1e-4 is a magic number from the original implementation unnormalized = initial + noise normalized = F.normalize(unnormalized, dim=-1) @@ -121,11 +124,10 @@ def add_embeddings(self, if self.num_embeddings.max() == (self._max_embeddings - 1): logger.warning('Embeddings full') - def update_embeddings(self, hidden: torch.Tensor, update_indices: torch.Tensor, - timestep: int, + timestep: torch.LongTensor, mask: Optional[torch.Tensor] = None) -> None: """ Updates existing embeddings. @@ -169,7 +171,7 @@ def update_embeddings(self, # dimension when accessing self.embeddings. Accordingly, the batch dimension of # normalized needs to be dropped in this case in order for assignment to work. self.embeddings[mask, update_indices[mask]] = normalized.squeeze(0) - self.last_seen[mask, update_indices[mask]] = timestep + self.last_seen[mask, update_indices[mask]] = timestep[mask] @overrides def forward(self, # pylint: disable=arguments-differ @@ -218,7 +220,9 @@ def forward(self, # pylint: disable=arguments-differ bilinear = bilinear.view(batch_size, -1) # Second half of equation 4. - distance_score = torch.exp(self._distance_scalar * (self.last_seen[mask].float() - timestep)) + distance_score = self._distance_scalar * (timestep[mask].float().unsqueeze(-1) - self.last_seen[mask].float()) + assert not (self.last_seen[mask] - timestep[mask].unsqueeze(-1)).gt(0).any() + logits = bilinear + distance_score # Since we pre-allocate the embedding array, logits includes scores for all of the @@ -227,11 +231,12 @@ def forward(self, # pylint: disable=arguments-differ num_embeddings = self.num_embeddings[mask].unsqueeze(1) arange = torch.arange(self._max_embeddings, device=num_embeddings.device).repeat(mask.sum(), 1) logit_mask = arange.lt(num_embeddings) - logits[logit_mask != 1] = 1e-34 + logits[logit_mask != 1] = -1e34 out = { 'logits': logits, - 'logit_mask': logit_mask + 'logit_mask': logit_mask, + 'logp': F.log_softmax(logits, -1) } if target is not None: @@ -242,3 +247,16 @@ def forward(self, # pylint: disable=arguments-differ return out + def beam_state(self): + beam_state = { + 'embeddings': self.embeddings.detach(), + 'num_embeddings': self.num_embeddings.detach(), + 'last_seen': self.last_seen.detach() + } + return beam_state + + def load_beam_state(self, beam_state): + self.embeddings = beam_state.get('embeddings', None) + self.num_embeddings = beam_state.get('num_embeddings', None) + self.last_seen = beam_state.get('last_seen', None) + diff --git a/kglm/modules/recent_entities.py b/kglm/modules/recent_entities.py index 00eb0eb..d4cfe05 100644 --- a/kglm/modules/recent_entities.py +++ b/kglm/modules/recent_entities.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import Dict, List, Tuple from allennlp.modules.token_embedders import TokenEmbedder @@ -72,7 +73,6 @@ def __call__(self, k = candidate_lookup[i][parent_id] candidate_mask[i, j + 1 : j + self._cutoff + 1, k] = 1 # Track how many sequence elements remain - remainder = sequence_length - (j + self._cutoff + 1) self._remaining[i][parent_id] = (j + self._cutoff + 1) - sequence_length # Remove any ids for non-recent parents (e.g. 
those without remaining mask) @@ -156,3 +156,9 @@ def insert(self, values: torch.LongTensor, mask: torch.ByteTensor = None) -> Non else: self._remaining[i][values[i].item()] = self._cutoff + 1 + def beam_state(self): + beam_state = {'remaining': self._remaining} + return deepcopy(beam_state) + + def load_beam_state(self, beam_state): + self._remaining = beam_state.get('remaining', []) diff --git a/kglm/modules/weight_drop.py b/kglm/modules/weight_drop.py index 0370711..0bfef8a 100644 --- a/kglm/modules/weight_drop.py +++ b/kglm/modules/weight_drop.py @@ -1,10 +1,21 @@ +import logging +from typing import Dict, List, Optional, Tuple + +from overrides import overrides import torch from torch.nn import Parameter import torch.nn.functional as F +logger = logging.getLogger(__name__) + + +LstmState = Tuple[torch.FloatTensor, torch.FloatTensor] +StateDict = Dict[str, LstmState] + class WeightDrop(torch.nn.Module): "A module that warps another layer in which some weights will be replaced by 0 during training." + # pylint: disable=protected-access def __init__(self, module, weights, dropout=0): super().__init__() @@ -13,9 +24,9 @@ def __init__(self, module, weights, dropout=0): self.dropout = dropout for weight in self.weights: #Makes a copy of the weights of the selected layers. - w = getattr(self.module, weight) - self.register_parameter(f'{weight}_raw', Parameter(w.data)) - self.module._parameters[weight] = F.dropout(w, p=self.dropout, training=False) + raw_w = getattr(self.module, weight) + self.register_parameter(f'{weight}_raw', Parameter(raw_w.data)) + self.module._parameters[weight] = F.dropout(raw_w, p=self.dropout, training=False) def _setweights(self): "Apply dropout to the raw weights." @@ -31,5 +42,78 @@ def reset(self): for weight in self.weights: raw_w = getattr(self, f'{weight}_raw') self.module._parameters[weight] = F.dropout(raw_w, p=self.dropout, training=False) - if hasattr(self.module, 'reset'): self.module.reset() + if hasattr(self.module, 'reset'): + self.module.reset() + + +class WeightDroppedLstm(torch.nn.Module): + def __init__(self, + num_layers: int, + input_embedding_dim: int, + hidden_size: int, + output_embedding_dim: Optional[int] = None, + dropout: Optional[float] = 0.0) -> None: + super().__init__() + + self._num_layers = num_layers + self._input_embedding_dim = input_embedding_dim + self._hidden_size = hidden_size + if output_embedding_dim is not None: + self._output_embedding_dim = output_embedding_dim + else: + self._output_embedding_dim = input_embedding_dim + self._dropout = dropout + self._state: Optional[StateDict] = None + + # Create an LSTM for each layer and apply weight drop. + rnns: List[torch.nn.Module] = [] + for i in range(num_layers): + if i == 0: + input_size = self._input_embedding_dim + else: + input_size = self._hidden_size + if i == num_layers - 1: + output_size = self._output_embedding_dim + else: + output_size = hidden_size + rnns.append(torch.nn.LSTM(input_size, output_size, batch_first=True)) + rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=dropout) for rnn in rnns] + + self._rnns = torch.nn.ModuleList(rnns) + + @overrides + def forward(self, embeddings: torch.FloatTensor) -> torch.FloatTensor: # pylint: disable=arguments-differ + current_input = embeddings + hidden_list = [] + for layer, rnn in enumerate(self._rnns): + # Retrieve previous hidden state for layer. Weird syntax in order to appease MyPy. + prev_hidden: Optional[LstmState] = None + if self._state is not None: + prev_hidden = self._state['layer_%i' % layer] + # Forward-pass. 
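+ # Each `rnn` below is a WeightDrop wrapper around a single-layer LSTM: during training,
+ # the hidden-to-hidden weight it applies is a dropped-out copy of the stored
+ # `weight_hh_l0_raw` parameter (DropConnect on the recurrent connections) rather than
+ # the raw weight itself.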
+ output, hidden = rnn(current_input, prev_hidden) + output = output.contiguous() + # Update hidden state for layer. + hidden = tuple(h.detach() for h in hidden) + hidden_list.append(hidden) + current_input = output + self._state = {'layer_%i' % i: h for i, h in enumerate(hidden_list)} + return current_input + + def reset(self, reset: torch.ByteTensor = None) -> None: + """Resets the internal hidden states""" + # pylint: disable=invalid-name + if reset is None: + logger.debug('Fully resetting LSTM state') + self._state = None + elif reset.all(): + logger.debug('Fully resetting LSTM state') + self._state = None + if self._state is None: + return + for layer in range(self._num_layers): + h, c = self._state['layer_%i' % layer] + h[:, reset, :] = torch.zeros_like(h[:, reset, :]) + c[:, reset, :] = torch.zeros_like(c[:, reset, :]) + self._state['layer_%i' % layer] = (h, c) diff --git a/kglm/run.py b/kglm/run.py index 708b16a..e961cba 100644 --- a/kglm/run.py +++ b/kglm/run.py @@ -16,10 +16,12 @@ from allennlp.commands import main from kglm.commands import EvaluatePerplexity from kglm.commands import CompleteTheSentence +from kglm.commands import BeamSum if __name__ == "__main__": main(prog="allennlp", subcommand_overrides={ 'evaluate-perplexity': EvaluatePerplexity(), 'complete-the-sentence': CompleteTheSentence(), + 'beam-sum': BeamSum() }) diff --git a/kglm/tests/dataset_readers/conll2012_test.py b/kglm/tests/dataset_readers/conll2012_test.py index 9a8156b..b06573a 100644 --- a/kglm/tests/dataset_readers/conll2012_test.py +++ b/kglm/tests/dataset_readers/conll2012_test.py @@ -54,11 +54,11 @@ def test_read_from_file(self, lazy): 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) np.testing.assert_allclose(instances[0]["mention_lengths"].array, - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, - 1, 1, 1, 6, 5, 4, 3, 2, 1, 1, 1, 1, 2, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1]) + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0]) class TestConll2012JsonlReader: @pytest.mark.parametrize('lazy', (True, False)) @@ -70,20 +70,16 @@ def test_read_from_file(self, lazy, offset): assert len(instances) == 2 first_instance_tokens = [x.text for x in instances[0]['source'].tokens] - assert first_instance_tokens[:5] == ["@@START@@", "Jesus", "left", "and", "went"] - assert first_instance_tokens[-5:] == [ "long", "ago", ".", "''", "@@END@@"] + assert first_instance_tokens[:5] == ["@@START@@", "in", "the", "summer", "of"] + assert first_instance_tokens[-5:] == [ "mainland", "china", "tourist", "market", "@@END@@"] second_instance_entity_ids = instances[1]['entity_ids'].array second_instance_mention_lengths = instances[1]['mention_lengths'].array second_instance_entity_types = instances[1]['entity_types'].array - np.testing.assert_allclose(second_instance_entity_types[(1 - offset):(3 - offset)], - np.array([1,0], dtype=np.uint8)) - np.testing.assert_allclose(second_instance_entity_ids[(1 - offset):(2 - offset)], - np.array([1], dtype=np.int64)) - np.testing.assert_allclose(second_instance_entity_ids[(8 - offset):(9 - offset)], - np.array([1], dtype=np.int64)) - np.testing.assert_allclose(second_instance_entity_ids[(30 - offset):(32 - offset)], + np.testing.assert_allclose(second_instance_entity_types[(9 - offset):(13 - offset)], + np.array([0,1,1,0], 
dtype=np.uint8)) + np.testing.assert_allclose(second_instance_entity_ids[(10 - offset):(12 - offset)], np.array([1, 1], dtype=np.int64)) - np.testing.assert_allclose(second_instance_mention_lengths[(30 - offset):(32 - offset)], - np.array([2, 1], dtype=np.int64)) \ No newline at end of file + np.testing.assert_allclose(second_instance_mention_lengths[(10 - offset):(12 - offset)], + np.array([1, 0], dtype=np.int64)) diff --git a/kglm/tests/dataset_readers/enhanced_wikitext_test.py b/kglm/tests/dataset_readers/enhanced_wikitext_test.py index a3e0c9a..c284673 100644 --- a/kglm/tests/dataset_readers/enhanced_wikitext_test.py +++ b/kglm/tests/dataset_readers/enhanced_wikitext_test.py @@ -23,9 +23,9 @@ def test_read_from_file(self, lazy): np.testing.assert_allclose(instances[1]["entity_ids"].array[:5], [0, 0, 1, 1, 1]) np.testing.assert_allclose(instances[1]["entity_ids"].array[-5:], [0, 0, 0, 0, 0]) np.testing.assert_allclose(instances[1]["mention_lengths"].array[:5], - [1, 1, 5, 4, 3]) + [0, 0, 4, 3, 2]) np.testing.assert_allclose(instances[1]["mention_lengths"].array[-5:], - [1, 1, 1, 1, 1]) + [0, 0, 0, 0, 0]) class TestEnhancedWikitextKglmReader: diff --git a/kglm/tests/fixtures/conll2012.jsonl b/kglm/tests/fixtures/conll2012.jsonl index 3d5e39c..a0bc1fa 100644 --- a/kglm/tests/fixtures/conll2012.jsonl +++ b/kglm/tests/fixtures/conll2012.jsonl @@ -1,2 +1,2 @@ -{"tokens": ["Jesus", "left", "and", "went", "back", "to", "his", "hometown", ".", "His", "followers", "went", "with", "him", ".", "On", "the", "Sabbath", "day", "Jesus", "taught", "in", "the", "synagogue", ",", "and", "many", "people", "heard", "him", ".", "They", "were", "amazed", "and", "said", ",", "``", "Where", "did", "this", "man", "get", "this", "teaching", "?", "How", "did", "he", "get", "such", "wisdom", "?", "Who", "gave", "it", "to", "him", "?", "And", "where", "did", "he", "get", "the", "power", "to", "do", "miracles", "?", "Is", "n't", "he", "just", "the", "carpenter", "we", "know", "--", "Mary", "'s", "son", ",", "the", "brother", "of", "James", ",", "Joses", ",", "Judas", ",", "and", "Simon", "?", "And", "do", "n't", "his", "sisters", "still", "live", "here", "in", "town", "?", "''", "So", "they", "had", "a", "problem", "accepting", "him", ".", "Then", "Jesus", "said", "to", "them", ",", "``", "People", "everywhere", "give", "honor", "to", "a", "prophet", ",", "except", "in", "his", "own", "town", ",", "with", "his", "own", "people", ",", "or", "in", "his", "home", ".", "''", "Jesus", "was", "not", "able", "to", "do", "any", "miracles", "there", "except", "the", "healing", "of", "some", "sick", "people", "by", "laying", "his", "hands", "on", "them", ".", "He", "was", "surprised", "that", "the", "people", "there", "had", "no", "faith", ".", "Then", "he", "went", "to", "other", "villages", "in", "that", "area", "and", "taught", ".", "Jesus", "called", "his", "twelve", "apostles", "together", ".", "He", "sent", "them", "out", "in", "groups", "of", "two", "and", "gave", "them", "power", "over", "evil", "spirits", ".", "This", "is", "what", "he", "told", "them", ":", "``", "Take", "nothing", "for", "your", "trip", "except", "a", "stick", "for", "walking", ".", "Take", "no", "bread", ",", "no", "bag", ",", "and", "no", "money", ".", "You", "can", "wear", "sandals", ",", "but", "do", "n't", "take", "extra", "clothes", ".", "When", "you", "enter", "a", "house", ",", "stay", "there", "until", "you", "leave", "that", "town", ".", "If", "any", "town", "refuses", "to", "accept", "you", "or", "refuses", "to", "listen", "to", "you", 
",", "then", "leave", "that", "town", "and", "shake", "the", "dust", "off", "your", "feet", "as", "a", "warning", "to", "them", ".", "''", "The", "apostles", "left", "and", "went", "to", "other", "places", ".", "They", "talked", "to", "the", "people", "and", "told", "them", "to", "change", "their", "hearts", "and", "lives", ".", "They", "forced", "many", "demons", "out", "of", "people", "and", "put", "olive", "oil", "on", "many", "who", "were", "sick", "and", "healed", "them", ".", "King", "Herod", "heard", "about", "Jesus", ",", "because", "Jesus", "was", "now", "famous", ".", "Some", "people", "said", ",", "``", "He", "is", "John", "the", "Baptizer", ".", "He", "must", "have", "risen", "from", "death", ",", "and", "that", "is", "why", "he", "can", "do", "these", "miracles", ".", "''", "Other", "people", "said", ",", "``", "He", "is", "Elijah", ".", "''", "And", "others", "said", ",", "``", "He", "is", "a", "prophet", ".", "He", "is", "like", "the", "prophets", "who", "lived", "long", "ago", ".", "''"], "clusters": {"2": [[0, 1], [6, 7], [13, 14], [19, 20], [29, 30], [40, 42], [48, 49], [57, 58], [62, 63], [72, 73], [98, 99], [113, 114], [116, 117], [147, 148], [165, 166], [170, 171], [182, 183], [193, 194], [200, 201], [219, 220], [352, 353], [355, 356], [365, 366], [371, 372], [382, 383], [394, 395], [404, 405], [409, 410]], "23": [[9, 11], [195, 198], [202, 203], [210, 211], [221, 222], [227, 228], [246, 247], [259, 260], [267, 268], [278, 279], [284, 285], [295, 296], [304, 306], [313, 314], [328, 329]], "8": [[26, 28], [31, 32], [76, 77], [108, 109], [119, 120]], "7": [[50, 52], [55, 56]], "33": [[127, 129], [132, 133], [137, 138], [143, 144]], "29": [[160, 163], [168, 169]], "35": [[273, 275], [288, 290]], "10": [[316, 318], [320, 321], [323, 324]], "15": [[340, 344], [346, 347]], "3": [[374, 375], [379, 380]]}} -{"tokens": ["Herod", "heard", "these", "things", "about", "Jesus", ".", "He", "said", ",", "``", "I", "killed", "John", "by", "cutting", "off", "his", "head", ".", "Now", "he", "has", "been", "raised", "from", "death", "!", "''", "Herod", "himself", "had", "ordered", "his", "soldiers", "to", "arrest", "John", "and", "put", "him", "in", "prison", ".", "Herod", "did", "this", "to", "please", "his", "wife", "Herodias", ".", "She", "had", "been", "married", "to", "Herod", "'s", "brother", "Philip", ",", "but", "then", "Herod", "married", "her", ".", "John", "told", "Herod", ",", "``", "It", "is", "not", "right", "for", "you", "to", "be", "married", "to", "your", "brother", "'s", "wife", ".", "''", "So", "Herodias", "hated", "John", ".", "She", "wanted", "him", "dead", ",", "but", "she", "was", "not", "able", "to", "persuade", "Herod", "to", "kill", "him", ".", "Herod", "was", "afraid", "to", "kill", "John", ",", "because", "he", "knew", "that", "he", "was", "a", "good", "and", "holy", "man", ".", "So", "he", "protected", "him", ".", "He", "liked", "listening", "to", "John", ",", "although", "what", "John", "said", "left", "him", "with", "so", "many", "questions", ".", "Then", "the", "right", "time", "came", "for", "Herodias", "to", "cause", "John", "'s", "death", ".", "It", "happened", "on", "Herod", "'s", "birthday", ".", "Herod", "gave", "a", "dinner", "party", "for", "the", "most", "important", "government", "leaders", ",", "the", "commanders", "of", "his", "army", ",", "and", "the", "most", "important", "people", "in", "Galilee", ".", "The", "daughter", "of", "Herodias", "came", "to", "the", "party", "and", "danced", ".", "When", "she", "danced", ",", "Herod", "and", "the", 
"people", "eating", "with", "him", "were", "very", "pleased", ".", "So", "King", "Herod", "said", "to", "the", "girl", ",", "``", "I", "will", "give", "you", "anything", "you", "want", ".", "''", "He", "promised", "her", ",", "``", "Anything", "you", "ask", "for", "I", "will", "give", "to", "you", "--", "even", "half", "of", "my", "kingdom", ".", "''", "The", "girl", "went", "to", "her", "mother", "and", "asked", ",", "``", "What", "should", "I", "ask", "King", "Herod", "to", "give", "me", "?", "''", "Her", "mother", "answered", ",", "``", "Ask", "for", "the", "head", "of", "John", "the", "Baptizer", ".", "''", "So", "right", "then", "the", "girl", "went", "back", "in", "to", "the", "king", ".", "She", "said", "to", "him", ",", "``", "Please", "give", "me", "the", "head", "of", "John", "the", "Baptizer", ".", "Bring", "it", "to", "me", "now", "on", "a", "plate", ".", "''", "King", "Herod", "was", "very", "sad", ",", "but", "he", "did", "n't", "want", "to", "break", "the", "promise", "he", "had", "made", "to", "her", "in", "front", "of", "his", "guests", ".", "So", "he", "sent", "a", "soldier", "to", "cut", "off", "John", "'s", "head", "and", "bring", "it", "to", "him", ".", "The", "soldier", "went", "and", "cut", "off", "John", "'s", "head", "in", "the", "prison", ".", "He", "brought", "the", "head", "back", "on", "a", "plate", "and", "gave", "it", "to", "the", "girl", ",", "and", "the", "girl", "gave", "it", "to", "her", "mother", ".", "John", "'s", "followers", "heard", "about", "what", "happened", ",", "so", "they", "came", "and", "got", "John", "'s", "body", "and", "put", "it", "in", "a", "tomb", ".", "The", "apostles", "Jesus", "had", "sent", "out", "came", "back", "to", "him", ".", "They", "gathered", "around", "him", "and", "told", "him", "about", "all", "they", "had", "done", "and", "taught", ".", "Jesus", "and", "his", "followers", "were", "in", "a", "very", "busy", "place", ".", "There", "were", "so", "many", "people", "that", "he", "and", "his", "followers", "did", "not", "even", "have", "time", "to", "eat", ".", "He", "said", "to", "them", ",", "``", "Come", "with", "me", ".", "We", "will", "go", "to", "a", "quiet", "place", "to", "be", "alone", ".", "There", "we", "will", "get", "some", "rest", ".", "''"], "clusters": {"2": [[0, 1], [7, 8], [11, 12], [29, 31], [33, 34], [44, 45], [65, 66], [71, 72], [79, 80], [107, 108], [112, 113], [120, 121], [132, 133], [136, 137], [147, 148], [169, 171], [173, 174], [214, 215], [226, 228], [234, 235], [243, 244], [252, 253], [261, 262], [279, 281], [310, 312], [316, 317], [339, 341], [346, 347], [354, 355], [366, 367], [380, 381]], "19": [[5, 6], [451, 452], [456, 457], [459, 460], [497, 498], [505, 506]], "32": [[13, 14], [21, 22], [37, 38], [40, 41], [69, 70], [93, 94], [97, 98], [110, 111], [117, 118], [123, 124], [134, 135], [140, 141], [144, 145], [162, 164]], "27": [[17, 19], [293, 299], [322, 328], [330, 331], [373, 376], [378, 379], [388, 391], [397, 399], [405, 406], [414, 415]], "26": [[32, 33], [46, 47]], "17": [[49, 52], [53, 54], [67, 68], [84, 88], [91, 92], [95, 96], [101, 102], [159, 160], [269, 271], [286, 288], [416, 418]], "13": [[58, 62]], "0": [[175, 178], [205, 207]], "31": [[179, 198], [216, 221], [362, 364]], "14": [[199, 203], [211, 212], [230, 232], [237, 238], [239, 240], [245, 246], [249, 250], [256, 257], [265, 267], [277, 278], [283, 284], [304, 306], [313, 314], [321, 322], [332, 333], [358, 359], [407, 409], [411, 413]], "9": [[368, 370], [382, 384], [395, 396]], "22": [[419, 422], [428, 429]], "30": [[432, 
435], [437, 438]], "18": [[442, 448], [453, 454], [462, 463], [500, 501]], "1": [[468, 472], [485, 489], [507, 508], [519, 520]]}} +{"tokens": ["@@START@@", "in", "the", "summer", "of", "@@NUM@@", "a", "picture", "that", "people", "have", "long", "been", "looking", "forward", "to", "started", "emerging", "with", "frequency", "in", "various", "major", "hong", "kong", "media", "@@END@@", "@@START@@", "with", "their", "unique", "charm", "these", "well", "-", "known", "cartoon", "images", "once", "again", "caused", "hong", "kong", "to", "be", "a", "focus", "of", "worldwide", "attention", "@@END@@", "@@START@@", "the", "world", "'s", "fifth", "disney", "park", "will", "soon", "open", "to", "the", "public", "here", "@@END@@", "@@START@@", "the", "most", "important", "thing", "about", "disney", "is", "that", "it", "is", "a", "global", "brand", "@@END@@", "@@START@@", "well", "for", "several", "years", "although", "it", "was", "still", "under", "construction", "and", "er", "not", "yet", "open", "it", "can", "be", "said", "that", "many", "people", "have", "viewed", "hong", "kong", "with", "new", "respect", "@@END@@", "@@START@@", "then", "welcome", "to", "the", "official", "writing", "ceremony", "of", "hong", "kong", "disneyland", "@@END@@", "@@START@@", "the", "construction", "of", "hong", "kong", "disneyland", "began", "two", "years", "ago", "in", "@@NUM@@", "@@END@@", "@@START@@", "in", "january", "of", "that", "year", "the", "hong", "kong", "government", "turned", "over", "to", "disney", "corporation", "@@NUM@@", "hectares", "of", "land", "at", "the", "foot", "of", "lantau", "island", "that", "was", "obtained", "following", "the", "largest", "land", "reclamation", "project", "in", "recent", "years", "@@END@@", "@@START@@", "one", "@@END@@", "@@START@@", "since", "then", "this", "area", "has", "become", "a", "prohibited", "zone", "in", "hong", "kong", "@@END@@", "@@START@@", "as", "its", "neighbor", "on", "lantau", "island", "hong", "kong", "international", "airport", "had", "to", "change", "its", "flight", "routes", "to", "make", "this", "area", "a", "no", "fly", "zone", "@@END@@", "@@START@@", "mickey", "mouse", "'s", "new", "home", "settling", "on", "chinese", "land", "for", "the", "first", "time", "has", "captured", "worldwide", "attention", "@@END@@", "@@START@@", "there", "'s", "only", "one", "month", "left", "before", "the", "opening", "of", "hong", "kong", "disneyland", "on", "september", "@@NUM@@", "@@END@@", "@@START@@", "the", "subway", "to", "disney", "has", "already", "been", "constructed", "@@END@@", "@@START@@", "at", "subway", "stations", "passengers", "will", "frequently", "press", "the", "station", "for", "disney", "on", "ticket", "machines", "trying", "to", "purchase", "tickets", "to", "enjoy", "the", "park", "when", "it", "first", "opens", "@@END@@", "@@START@@", "meanwhile", "the", "disney", "subway", "station", "is", "scheduled", "to", "open", "on", "the", "same", "day", "as", "the", "park", "@@END@@", "@@START@@", "for", "two", "years", "disney", "has", "constantly", "maintained", "its", "mystery", "@@END@@", "@@START@@", "no", "media", "have", "been", "allowed", "to", "enter", "for", "photos", "@@END@@", "@@START@@", "we", "took", "a", "taxi", "along", "the", "path", "of", "the", "highway", "that", "heads", "toward", "disney", "trying", "to", "experience", "this", "mysterious", "park", "from", "close", "by", "@@END@@", "@@START@@", "however", "before", "any", "of", "the", "disney", "symbols", "were", "in", "sight", "the", "car", "was", "stopped", "by", "a", "security", 
"guard", "at", "the", "intersection", "of", "the", "road", "towards", "disney", "@@END@@", "@@START@@", "on", "our", "way", "back", "the", "taxi", "driver", "gave", "us", "an", "explanation", "after", "understanding", "our", "intentions", "@@END@@", "@@START@@", "er", "according", "to", "what", "the", "security", "guard", "said", "for", "the", "time", "before", "everything", "is", "officially", "opened", "no", "cars", "can", "enter", "unless", "they", "have", "special", "permission", "@@END@@", "@@START@@", "no", "one", "can", "enter", "otherwise", "@@END@@", "@@START@@", "video", "recording", "is", "especially", "forbidden", "@@END@@", "@@START@@", "ah", "everything", "is", "top", "secret", "@@END@@", "@@START@@", "if", "pictures", "are", "taken", "without", "permission", "that", "is", "to", "say", "it", "will", "at", "all", "times", "be", "pursued", "by", "legal", "action", "a", "big", "hassle", "@@END@@", "@@START@@", "although", "disney", "corporation", "chose", "hong", "kong", "as", "the", "venue", "for", "the", "chinese", "disney", "park", "what", "they", "are", "actually", "most", "excited", "about", "is", "the", "mainland", "china", "tourist", "market", "@@END@@"], "clusters": {"23": [[23, 25], [41, 43], [106, 108], [146, 148], [191, 193], [483, 485]], "18": [[29, 30], [32, 38]], "12": [[52, 58], [87, 88], [121, 124], [129, 132], [221, 226], [250, 253], [261, 262], [278, 279], [288, 290], [291, 292], [298, 299], [317, 318], [321, 322], [349, 350], [353, 356], [386, 387], [489, 493]], "50": [[72, 73], [75, 76], [152, 154], [366, 367], [480, 482], [494, 495]], "13": [[137, 138], [143, 145]], "14": [[154, 176], [183, 185], [196, 197], [213, 215]], "8": [[199, 201]], "29": [[201, 205], [208, 209]], "4": [[254, 256], [306, 312]], "40": [[336, 337], [390, 391], [397, 398], [402, 403]], "21": [[338, 340], [371, 373]], "31": [[376, 379], [410, 413]]}} +{"tokens": ["@@START@@", "since", "the", "implementation", "of", "the", "individual", "visit", "scheme", "between", "hong", "kong", "and", "the", "mainland", "more", "and", "more", "mainland", "tourists", "are", "coming", "to", "visit", "hong", "kong", "@@END@@", "@@START@@", "from", "the", "beginning", "up", "till", "now", "more", "than", "seven", "million", "individual", "tourists", "have", "come", "to", "hong", "kong", "@@END@@", "@@START@@", "well", "we", "now", "er", "believe", "more", "will", "be", "coming", "@@END@@", "@@START@@", "at", "this", "point", "it", "has", "been", "about", "two", "years", "@@END@@", "@@START@@", "also", "the", "current", "number", "of", "@@NUM@@", "cities", "will", "be", "increased", "@@END@@", "@@START@@", "hong", "kong", "was", "developed", "from", "a", "fishing", "harbor", "one", "hundred", "years", "ago", "to", "become", "today", "'s", "international", "metropolis", "@@END@@", "@@START@@", "here", "eastern", "and", "western", "cultures", "have", "gathered", "and", "the", "new", "and", "the", "old", "coexist", "@@END@@", "@@START@@", "when", "in", "hong", "kong", "you", "can", "wander", "among", "skyscrapers", "heartily", "enjoy", "shopping", "sprees", "in", "well", "known", "stores", "and", "malls", "for", "goods", "from", "various", "countries", "and", "taste", "delicious", "snacks", "from", "all", "over", "the", "world", "at", "tea", "shops", "or", "at", "street", "stands", "in", "mong", "kok", "@@END@@", "@@START@@", "you", "can", "go", "to", "burn", "incense", "and", "make", "a", "vow", "at", "the", "repulse", "bay", "where", "all", "deities", "gather", "@@END@@", "@@START@@", "you", "can", "enjoy", 
"the", "most", "charming", "sun", "filled", "sandy", "beaches", "in", "hong", "kong", "@@END@@", "@@START@@", "you", "can", "ascend", "victoria", "peak", "to", "get", "a", "panoramic", "view", "of", "victoria", "harbor", "'s", "beautiful", "scenery", "@@END@@", "@@START@@", "or", "hop", "onto", "a", "trolley", "with", "over", "a", "century", "of", "history", "and", "feel", "the", "city", "'s", "blend", "of", "the", "old", "and", "the", "modern", "in", "slow", "motion", "@@END@@"], "clusters": {"19": [[10, 12], [24, 26], [43, 45], [81, 83], [119, 121], [193, 195], [228, 231]]}} diff --git a/kglm/tests/fixtures/kglm.model.tar.gz b/kglm/tests/fixtures/kglm.model.tar.gz index 4f1b44f..3ff743b 100644 Binary files a/kglm/tests/fixtures/kglm.model.tar.gz and b/kglm/tests/fixtures/kglm.model.tar.gz differ diff --git a/kglm/tests/models/entity_nlm_test.py b/kglm/tests/models/entity_nlm_test.py index 4b5e9d4..8b62e94 100644 --- a/kglm/tests/models/entity_nlm_test.py +++ b/kglm/tests/models/entity_nlm_test.py @@ -32,3 +32,65 @@ def setUp(self): def test_model_can_train_save_and_load(self): self.ensure_model_can_train_save_and_load(self.param_file, gradients_to_ignore=['_dummy_context_embedding']) + + def test_annotation_logp(self): + batch_size = 2 + + # Need to reset the states + reset = torch.ByteTensor([1] * batch_size) + self.model.reset_states(reset) + + # Apply to random hidden state + hidden = torch.randn(batch_size, self.model._embedding_dim) + beam_states = [self.model._dynamic_embeddings.beam_state()] + logp = self.model._annotation_logp(hidden, timestep=0, beam_states=beam_states) + + # Check that output has correct shape + assert tuple(logp.shape) == (batch_size, 1, self.model.num_possible_annotations) + + def test_adjust_for_ongoing_mentions(self): + batch_size = 2 + k = 3 + + # Construct an example where the top-beam state for the second sequence in the batch is an ongoing mention + logp = torch.zeros(batch_size, k, self.model.num_possible_annotations) + output = { + 'entity_ids': torch.LongTensor([[0, 0, 0], [3, 0, 0]]), + 'mention_lengths': torch.LongTensor([[0, 0, 0], [2, 0, 0]]) + } + + # See that adjustment works + logp = self.model._adjust_for_ongoing_mentions(logp, output) + assert logp[0, 0, 0] == 0.0 # Should be unaffected + assert logp[1, 0, 0] == -float('inf') # Should be affected + + # Only element with probability should have entity id == 3 and mention length == 2 + pred = logp[1, 0].argmax() + assert self.model.entity_id_lookup[pred] == 3 + assert self.model.mention_length_lookup[pred] == 1 + + def test_top_k_annotations(self): + batch_size = 2 + k = 3 + + # Check works correctly at start (e.g. if beam size is 1) + logp = torch.randn(batch_size, 1, self.model.num_possible_annotations) + annotations = self.model._top_k_annotations(logp, k) + + assert tuple(annotations['logp'].shape) == (batch_size, k) + assert torch.allclose(annotations['backpointers'], torch.zeros(batch_size, k, dtype=torch.int64)) + + # Check works correctly for other timesteps (e.g. 
previous beam size is k) + logp = torch.randn(batch_size, k, self.model.num_possible_annotations) + annotations = self.model._top_k_annotations(logp, k) + + assert tuple(annotations['logp'].shape) == (batch_size, k) + + def test_beam_search(self): + batch_size = 2 + seq_len = 10 + k = 3 + vocab_size = self.model.vocab.get_vocab_size('tokens') + source = {'tokens': torch.randint(vocab_size, size=(batch_size, seq_len))} + reset = torch.ones(batch_size, dtype=torch.uint8) + out = self.model.beam_search(source, reset, k) diff --git a/kglm/tests/models/kglm_test.py b/kglm/tests/models/kglm_test.py index 54307c6..c7a2c91 100644 --- a/kglm/tests/models/kglm_test.py +++ b/kglm/tests/models/kglm_test.py @@ -8,7 +8,8 @@ from kglm.common.testing import KglmModelTestCase from kglm.data.dataset_readers.enhanced_wikitext import EnhancedWikitextKglmReader from kglm.models.kglm import Kglm -from kglm.models.kglm_disc import KglmDisc +from kglm.models.kglm_disc import KglmDisc, KglmBeamState +from kglm.modules import RecentEntities class KglmTest(KglmModelTestCase): @@ -52,7 +53,7 @@ def test_sample(self): instances = list(reader.read(dataset_file)) iterator = DataIterator.from_params(generator_params['iterator']) iterator.index_with(self.model.vocab) - batch, _ = next(iterator(instances, shuffle=False)) + batch = next(iterator(instances, shuffle=False)) self.model.sample(**batch) @@ -79,7 +80,7 @@ def test_sample(self): iterator = DataIterator.from_params(generator_params['iterator']) iterator.index_with(self.model.vocab) - batch, _ = next(iterator(instances, shuffle=False)) + batch = next(iterator(instances, shuffle=False)) # Samples should match (we'll test by comparing logp) torch.manual_seed(123) @@ -87,3 +88,109 @@ def test_sample(self): torch.manual_seed(123) logp2 = self.model.sample(**batch).get('logp', None) + def test_beam_search(self): + generator_params = Params.from_file("kglm/tests/fixtures/training_config/kglm.no-shortlist.json") + params = Params.from_file(self.param_file) + dataset_file = "kglm/tests/fixtures/enhanced-wikitext-test/train.jsonl" + + # Need instances from 'generative' reader! + reader_params = generator_params['dataset_reader'] + reader_params['mode'] = 'generative' + reader = DatasetReader.from_params(reader_params) + instances = list(reader.read(dataset_file)) + + iterator = DataIterator.from_params(generator_params['iterator']) + iterator.index_with(self.model.vocab) + batch = next(iterator(instances, shuffle=False)) + + # Just want to check that function does not raise an error for now. + self.model.beam_search(batch['source'], + batch['target'], + batch['reset'], + batch['metadata'], + k=5) + + def test_next_mention_type_logp(self): + # Checks whether penalty correctly applied to ongoing mentions + batch_size = 1 + num_classes = 2 + k = 2 + + # All mention types have equal prob + next_mention_type_logits = torch.ones(batch_size, num_classes) + + # First beam has an ongoing mention, second does not + recent_entities_state = self.model._recent_entities.beam_state() + ongoing_0 = torch.ones(batch_size, dtype=torch.uint8) + ongoing_1 = torch.zeros(batch_size, dtype=torch.uint8) + beam_states = [ + KglmBeamState(recent_entities=recent_entities_state, ongoing=ongoing_0), + KglmBeamState(recent_entities=recent_entities_state, ongoing=ongoing_1) + ] + + next_mention_type_logp = self.model._next_mention_type_logp(next_mention_type_logits, beam_states) + # Log probabilities should be same on first beam, and different on second. 
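+ # Indexing below is [batch, beam, mention type]; `next_mention_type_logp` has shape
+ # (batch_size, k, num_classes).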
+ assert torch.allclose(next_mention_type_logp[0, 0, 0], next_mention_type_logp[0, 0, 1]) + assert not torch.allclose(next_mention_type_logp[0, 1, 0], next_mention_type_logp[0, 1, 1]) + # Log probability of first state (e.g., non-ongoing) should be close to 0.0 on second beam. + assert torch.allclose(next_mention_type_logp[0, 1, 0], torch.tensor(0.0)) + + def test_next_new_entity_logp(self): + # Checks whether penalty correctly applied to previously mentioned entities + batch_size = 1 + num_entities = 2 + k = 2 + + # All next entities have equal prob + next_new_entity_logits = torch.ones(batch_size, num_entities) + + # First entity is previously mentioned on first beam. + # No previous mentions on second beam. + ongoing = None # Value doesn't matter + recent_entities_state_0 = {'remaining': [{0 : None}]} + recent_entities_state_1 = {'remaining': [{}]} + beam_states = [ + KglmBeamState(recent_entities=recent_entities_state_0, ongoing=ongoing), + KglmBeamState(recent_entities=recent_entities_state_1, ongoing=ongoing)] + + next_new_entity_logp = self.model._next_new_entity_logp(next_new_entity_logits, beam_states) + # Log probabilities should be different on first beam, and same on second. + assert not torch.allclose(next_new_entity_logp[0, 0, 0], next_new_entity_logp[0, 0, 1]) + assert torch.allclose(next_new_entity_logp[0, 1, 0], next_new_entity_logp[0, 1, 1]) + # Log probability of non-recent entity should be close to 0.0 on first beam. + assert torch.allclose(next_new_entity_logp[0, 0, 1], torch.tensor(0.0)) + + def test_next_related_entity_logp(self): + # Checks that: + # * There is no probability mass if there are no candidates + # * Probability distribution is valid if there are candidates + # * Annotations look correct (e.g., parents ids are consistent) + batch_size = 1 + k = 2 + + next_encoded_head = torch.randn((batch_size, self.model.entity_embedding_dim)) + next_encoded_relation = torch.randn((batch_size, self.model.entity_embedding_dim)) + ongoing = None # Value doesn't matter + + # NOTE: `parent_id` = 5 chosen since this node in the knowledge graph has a relatively + # small number of outgoing edges. 
+ recent_entities_state_0 = {'remaining': [{5 : None}]} + recent_entities_state_1 = {'remaining': [{}]} + recent_entities_state_2 = {'remaining': [{5: None, 6: None}]} + beam_states = [ + KglmBeamState(recent_entities=recent_entities_state_0, ongoing=ongoing), + KglmBeamState(recent_entities=recent_entities_state_1, ongoing=ongoing), + KglmBeamState(recent_entities=recent_entities_state_2, ongoing=ongoing) + ] + + logp, annotations = self.model._next_related_entity_logp(next_encoded_head, + next_encoded_relation, + beam_states) + # Only first and last states will have probability mass + assert torch.allclose(logp[0, 0].exp().sum(), torch.tensor(1.0)) + assert torch.allclose(logp[0, 1].exp().sum(), torch.tensor(0.0)) + assert torch.allclose(logp[0, 2].exp().sum(), torch.tensor(1.0)) + + assert annotations['parent_ids'][0, 0].unique().size(0) == 2 # ids: 0, 5 + assert annotations['parent_ids'][0, 1].unique().size(0) == 1 # ids: 0 + assert annotations['parent_ids'][0, 2].unique().size(0) == 2 # ids: 0, 5, 6 diff --git a/kglm/training/trainer.py b/kglm/training/trainer.py index 8c00caa..5bd6c33 100644 --- a/kglm/training/trainer.py +++ b/kglm/training/trainer.py @@ -288,7 +288,7 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]: train_generator_tqdm = Tqdm.tqdm(raw_train_generator, total=num_training_batches) cumulative_batch_size = 0 - for batch, lr_mult in train_generator_tqdm: + for batch in train_generator_tqdm: batches_this_epoch += 1 self._batch_num_total += 1 batch_num_total = self._batch_num_total @@ -314,12 +314,6 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]: if self._learning_rate_scheduler: self._learning_rate_scheduler.step_batch(batch_num_total) - # We dynamically adjust the learning rate to account for slight variations in the input - # sequences - original_lr = self.optimizer.param_groups[0]['lr'] - batch_lr = original_lr * lr_mult - self.optimizer.param_groups[0]['lr'] = batch_lr - if self._tensorboard.should_log_histograms_this_batch(): # get the magnitude of parameter updates for logging # We need a copy of current parameters to compute magnitude of updates, @@ -336,8 +330,6 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]: else: self.optimizer.step() - self.optimizer.param_groups[0]['lr'] = original_lr - # Update moving averages if self._moving_average is not None: self._moving_average.apply(batch_num_total) @@ -401,7 +393,7 @@ def _validation_loss(self) -> Tuple[float, int]: total=num_validation_batches) batches_this_epoch = 0 val_loss = 0 - for batch, _ in val_generator_tqdm: + for batch in val_generator_tqdm: loss = self.batch_loss(batch, for_training=False) if loss is not None: