diff --git a/experiments/entity_disc_conll2012.jsonnet b/experiments/entity_disc_conll2012.jsonnet index 72bb063..2cee49d 100644 --- a/experiments/entity_disc_conll2012.jsonnet +++ b/experiments/entity_disc_conll2012.jsonnet @@ -2,7 +2,7 @@ "vocabulary": { "type": "extended", "extend": false, - "directory_path": "results/entity-nlm-conll2012-fixed/vocabulary" + "directory_path": "data/vocabulary" }, "dataset_reader": { "type": "conll2012_jsonl", @@ -36,7 +36,7 @@ }, "iterator": { "type": "fancy", - "batch_size": 16, + "batch_size": 343, "split_size": 30, "splitting_keys": [ "source", @@ -45,13 +45,26 @@ "mention_lengths" ], }, + "validation_iterator": { + "type": "fancy", + "batch_size": 343, + "split_size": 128, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + "truncate": false + }, "trainer": { "type": "lm", - "num_epochs": 40, + "num_epochs": 400, "cuda_device": 0, "optimizer": { "type": "adam", "lr": 1e-3 - } + }, + "validation_metric": "+eid_acc" } } diff --git a/experiments/entity_disc_conll2012_mini.jsonnet b/experiments/entity_disc_conll2012_mini.jsonnet new file mode 100644 index 0000000..5736946 --- /dev/null +++ b/experiments/entity_disc_conll2012_mini.jsonnet @@ -0,0 +1,70 @@ +{ + "vocabulary": { + "type": "extended", + "extend": false, + "directory_path": "data/vocabulary-mini" + }, + "dataset_reader": { + "type": "conll2012_jsonl", + "token_indexers": { + "tokens": { + "type": "single_id", + "lowercase_tokens": true + } + } + }, + "train_data_path": "data/conll-2012/processed/train-mini.jsonl", + "validation_data_path": "data/conll-2012/processed/dev-mini.jsonl", + "model": { + "type": "entitydisc", + "text_field_embedder": { + "token_embedders": { + "tokens": { + "type": "embedding", + "embedding_dim": 256, + "trainable": true + }, + }, + }, + "embedding_dim": 256, + "hidden_size": 256, + "num_layers": 1, + "max_mention_length": 50, + "max_embeddings": 10, + "dropout_rate": 0.4, + "variational_dropout_rate": 0.1 + }, + "iterator": { + "type": "fancy", + "batch_size": 16, + "split_size": 15, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + }, + "validation_iterator": { + "type": "fancy", + "batch_size": 16, + "split_size": 15, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + "truncate": false + }, + "trainer": { + "type": "lm", + "num_epochs": 400, + "cuda_device": 0, + "optimizer": { + "type": "adam", + "lr": 1e-3 + }, + "validation_metric": "+eid_acc" + } +} diff --git a/experiments/entity_disc_conll2012_no_peeking.jsonnet b/experiments/entity_disc_conll2012_no_peeking.jsonnet new file mode 100644 index 0000000..5b9b5bf --- /dev/null +++ b/experiments/entity_disc_conll2012_no_peeking.jsonnet @@ -0,0 +1,71 @@ +{ + "vocabulary": { + "type": "extended", + "extend": false, + "directory_path": "data/vocabulary" + }, + "dataset_reader": { + "type": "conll2012_jsonl", + "offset": 1, + "token_indexers": { + "tokens": { + "type": "single_id", + "lowercase_tokens": true + } + } + }, + "train_data_path": "data/conll-2012/processed/train.jsonl", + "validation_data_path": "data/conll-2012/processed/dev.jsonl", + "model": { + "type": "entitydisc", + "text_field_embedder": { + "token_embedders": { + "tokens": { + "type": "embedding", + "embedding_dim": 128, + "trainable": true + }, + }, + }, + "embedding_dim": 128, + "hidden_size": 128, + "num_layers": 1, + "max_mention_length": 100, + "max_embeddings": 100, + "dropout_rate": 0.4, + 
"variational_dropout_rate": 0.1 + }, + "iterator": { + "type": "fancy", + "batch_size": 16, + "split_size": 15, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + }, + "validation_iterator": { + "type": "fancy", + "batch_size": 16, + "split_size": 15, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + "truncate": false + }, + "trainer": { + "type": "lm", + "num_epochs": 400, + "cuda_device": 0, + "optimizer": { + "type": "adam", + "lr": 1e-4 + }, + "validation_metric": "+eid_acc" + } +} diff --git a/experiments/entity_disc_conll2012_no_peeking_mini.jsonnet b/experiments/entity_disc_conll2012_no_peeking_mini.jsonnet new file mode 100644 index 0000000..046f380 --- /dev/null +++ b/experiments/entity_disc_conll2012_no_peeking_mini.jsonnet @@ -0,0 +1,71 @@ +{ + "vocabulary": { + "type": "extended", + "extend": false, + "directory_path": "data/vocabulary-mini" + }, + "dataset_reader": { + "type": "conll2012_jsonl", + "offset": 1, + "token_indexers": { + "tokens": { + "type": "single_id", + "lowercase_tokens": true + } + } + }, + "train_data_path": "data/conll-2012/processed/train-mini.jsonl", + "validation_data_path": "data/conll-2012/processed/dev-mini.jsonl", + "model": { + "type": "entitydisc", + "text_field_embedder": { + "token_embedders": { + "tokens": { + "type": "embedding", + "embedding_dim": 128, + "trainable": true + }, + }, + }, + "embedding_dim": 128, + "hidden_size": 128, + "num_layers": 1, + "max_mention_length": 50, + "max_embeddings": 10, + "dropout_rate": 0.4, + "variational_dropout_rate": 0.1 + }, + "iterator": { + "type": "fancy", + "batch_size": 16, + "split_size": 15, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + }, + "validation_iterator": { + "type": "fancy", + "batch_size": 16, + "split_size": 15, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + "truncate": false + }, + "trainer": { + "type": "lm", + "num_epochs": 400, + "cuda_device": 0, + "optimizer": { + "type": "adam", + "lr": 1e-4 + }, + "validation_metric": "+eid_acc" + } +} diff --git a/experiments/entity_disc_test.jsonnet b/experiments/entity_disc_test.jsonnet index 4c9f6e8..ed932e0 100644 --- a/experiments/entity_disc_test.jsonnet +++ b/experiments/entity_disc_test.jsonnet @@ -2,7 +2,7 @@ "vocabulary": { "type": "extended", "extend": false, - "directory_path": "./results/entity-nlm-wt2.fixed-vocab.dropout.2/vocabulary" + "directory_path": "data/vocabulary" }, "dataset_reader": { "type": "enhanced-wikitext", @@ -59,4 +59,4 @@ "lr": 1e-3 } } -} \ No newline at end of file +} diff --git a/experiments/entity_nlm.jsonnet b/experiments/entity_nlm.jsonnet index f2a3eca..9bfc403 100644 --- a/experiments/entity_nlm.jsonnet +++ b/experiments/entity_nlm.jsonnet @@ -4,7 +4,7 @@ }, "iterator": { "type": "fancy", - "batch_size": 60, + "batch_size": 30, "split_size": 70, "splitting_keys": [ "source", @@ -13,6 +13,18 @@ "mention_lengths" ] }, + "validation_iterator": { + "type": "fancy", + "batch_size": 30, + "split_size": 70, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + "truncate": false + }, "model": { "type": "entitynlm", "dropout_rate": 0.5, diff --git a/experiments/entity_nlm_conll2012.jsonnet b/experiments/entity_nlm_conll2012.jsonnet index 5437e38..609151b 100644 --- a/experiments/entity_nlm_conll2012.jsonnet +++ b/experiments/entity_nlm_conll2012.jsonnet @@ -1,11 +1,8 @@ { "vocabulary": { 
"type": "extended", - "max_vocab_size": { - // This does not count the @@UNKNOWN@@ token, which - // ends up being our 10,000th token. - "tokens": 9999 - } + "extend": false, + "directory_path": "data/vocabulary" }, "dataset_reader": { "type": "conll2012_jsonl", @@ -41,14 +38,26 @@ }, "iterator": { "type": "fancy", - "batch_size": 16, - "split_size": 30, + "batch_size": 512, + "split_size": 15, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + }, + "validation_iterator": { + "type": "fancy", + "batch_size": 512, + "split_size": 15, "splitting_keys": [ "source", "entity_types", "entity_ids", "mention_lengths" ], + "truncate": false }, "trainer": { "type": "lm", diff --git a/experiments/entity_nlm_conll2012_mini.jsonnet b/experiments/entity_nlm_conll2012_mini.jsonnet new file mode 100644 index 0000000..e35a67b --- /dev/null +++ b/experiments/entity_nlm_conll2012_mini.jsonnet @@ -0,0 +1,71 @@ +{ + "vocabulary": { + "type": "extended", + "extend": false, + "directory_path": "data/vocabulary-mini" + }, + "dataset_reader": { + "type": "conll2012_jsonl", + "token_indexers": { + "tokens": { + "type": "single_id", + "lowercase_tokens": true + } + } + }, + "train_data_path": "data/conll-2012/processed/train-mini.jsonl", + "validation_data_path": "data/conll-2012/processed/dev-mini.jsonl", + "datasets_for_vocab_creation": ["train"], + "model": { + "type": "entitynlm", + "text_field_embedder": { + "token_embedders": { + "tokens": { + "type": "embedding", + "embedding_dim": 256, + "trainable": true + }, + }, + }, + "embedding_dim": 256, + "hidden_size": 256, + "num_layers": 1, + "max_mention_length": 50, + "max_embeddings": 10, + "tie_weights": true, + "dropout_rate": 0.4, + "variational_dropout_rate": 0.1 + }, + "iterator": { + "type": "fancy", + "batch_size": 16, + "split_size": 15, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + }, + "validation_iterator": { + "type": "fancy", + "batch_size": 16, + "split_size": 15, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + "truncate": false + }, + "trainer": { + "type": "lm", + "num_epochs": 400, + "cuda_device": 0, + "optimizer": { + "type": "adam", + "lr": 1e-3 + } + } +} diff --git a/kglm/commands/evaluate_perplexity.py b/kglm/commands/evaluate_perplexity.py index 68accc5..21dfeeb 100644 --- a/kglm/commands/evaluate_perplexity.py +++ b/kglm/commands/evaluate_perplexity.py @@ -14,6 +14,7 @@ from allennlp.models import Model from allennlp.models.archival import load_archive from allennlp.nn import util +import numpy as np import torch logger = logging.getLogger(__name__) @@ -50,81 +51,199 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argpar default="", help='a JSON structure used to override the experiment configuration') - subparser.add_argument('--batch-weight-key', - type=str, - default="", - help='If non-empty, name of metric used to weight the loss on a per-batch basis.') + subparser.add_argument('--batch-size', + type=int, + default=None, + help='Batch size (default: whatever iterator was set to)') + + subparser.add_argument('--split-size', + type=int, + default=None, + help='Split size (default: whatever iterator was set to)') subparser.add_argument('--num-samples', type=int, - default=100, + default=10000, help='Number of importance samples to draw.') + + subparser.add_argument('--samples-per-batch', + type=int, + default=1, + help='Number of importance samples to draw.') + + 
subparser.add_argument('--temperature', + type=float, + default=1.0) subparser.set_defaults(func=evaluate_from_args) + subparser.add_argument('--offset', + action='store_true') + subparser.set_defaults(func=evaluate_from_args) return subparser +PRESERVED_FIELDS = {'source', 'reset'} + + +def _offset(sample, held_over_data): + batch_size = sample['reset'].size(0) + new_sample = {'source': sample['source'], + 'reset': sample['reset']} + new_held_over_data = {} + for field in sample: + if field in PRESERVED_FIELDS: + continue + if held_over_data is None: + prefix = sample[field].new_zeros(batch_size) + else: + prefix = held_over_data[field] + new_sample[field] = torch.cat((prefix.unsqueeze(1), sample[field][:,:-1]), dim=1) + new_held_over_data[field] = sample[field][:,-1] + return new_sample, new_held_over_data + + +def tile(t, amount): + if isinstance(t, torch.Tensor): + args = [1 for _ in t.shape] + args[0] = amount + return t.repeat(*args) + elif isinstance(t, dict): + return {k: tile(v, amount) for k, v in t.items()} + + +def logsumexp(prev: torch.FloatTensor, + current: torch.FloatTensor, + i: int, + samples_per_batch: int): + # NOTE: n is number of samples + current_avg = current.view(samples_per_batch, -1).sum(dim=-1).logsumexp(dim=0) - np.log(samples_per_batch).item() + if prev is None: + return current_avg + a = torch.max(prev, current_avg) + sumexp = torch.exp(prev - a) * i / (i + 1) + torch.exp(current_avg - a) / (i + 1) + return a + torch.log(sumexp) + + def evaluate_perplexity(model: Model, sampler: Model, num_samples: int, instances: Iterator[Instance], data_iterator: DataIterator, - cuda_device: int) -> Dict[str, Any]: + cuda_device: int, + temperature: float = 1.0, + offset: bool = False, + samples_per_batch: int = 1) -> Dict[str, Any]: check_for_gpu(cuda_device) logger.info('Iterating over dataset') - with torch.no_grad(): + # summands = [] + # penalized_summands = [] + trajectory = np.zeros(num_samples // samples_per_batch) + individual_estimates = np.zeros(num_samples // samples_per_batch) + s_probs = np.zeros((348, num_samples // samples_per_batch)) + + + weight = None - summands = [] - penalized_summands = [] + for i in range(num_samples // samples_per_batch): + iterator = data_iterator(instances, num_epochs=1, shuffle=False) + generator_tqdm = Tqdm.tqdm(iterator, total=0) - for i in range(num_samples): - iterator = data_iterator(instances, num_epochs=1, shuffle=False) - generator_tqdm = Tqdm.tqdm(iterator, total=0) + model.eval() + sampler.eval() + sampler._state = None - model.eval() - sampler.eval() + summand = None + denom = None + #summand = torch.tensor(0.0) + # penalized_summand = torch.tensor(0.0) - summand = 0.0 - penalized_summand = 0.0 - denom = 0 - for batch, _ in generator_tqdm: + held_over_data = None - batch = util.move_to_device(batch, cuda_device) + for batch, _ in generator_tqdm: - # We need sequence length to help compute perplexity - n_tokens = util.get_text_field_mask(batch['source']).float().sum().item() + # We need sequence length to help compute perplexity + n_tokens = util.get_text_field_mask(batch['source']).float().sum(dim=-1) + if denom is None: + denom = n_tokens + else: denom += n_tokens - # Draw a sample - sampler_output = sampler.sample(**batch) - sample_logp = sampler_output['logp'] - sample = sampler_output['sample'] + summand = util.move_to_device(summand, cuda_device) + batch = util.move_to_device(batch, cuda_device) - # Evaluate on sample + # Tile if that's what we're doing + if samples_per_batch > 1: + batch = tile(batch, 
samples_per_batch) + + # Draw a sample + with torch.no_grad(): + sampler_output = sampler.sample(**batch, + temperature=temperature, + offset=offset) + sample_logp = sampler_output['logp'] + sample = sampler_output['sample'] + + if offset: + sample, held_over_data = _offset(sample, held_over_data) + + # Evaluate on sample + with torch.no_grad(): model_output = model(**sample) - model_logp = model_output['logp'] - model_penalized_logp = model_output['penalized_logp'] - summand += (model_logp - sample_logp).item() - penalized_summand += (model_penalized_logp - sample_logp).item() - - summands.append(summand) - penalized_summands.append(penalized_summand) - t = torch.tensor(summands) - p = torch.tensor(penalized_summands) - t_sum = torch.logsumexp(t, dim=0) - p_sum = torch.logsumexp(p, dim=0) - sum_logp = (t_sum - math.log(i+1)).item() - sum_logp_penalized = (p_sum - math.log(i+1)).item() - ppl = math.exp(-sum_logp / denom) - upp = math.exp(-sum_logp_penalized / denom) - - print('PPL: %f' % ppl) - print('UPP: %f' % upp) - - metrics = {'ppl': ppl, 'upp': upp} + + model_logp = model_output['logp'] + if summand is None: + summand = (model_logp - sample_logp) + else: + summand += (model_logp - sample_logp) + + # model_penalized_logp = model_output['penalized_logp'] + # penalized_summand += (model_penalized_logp - sample_logp) + + # generator_tqdm.set_description('Instantaneous PPL: %0.4f' % torch.exp((sample_logp - model_logp) / n_tokens).item()) + + + current_avg = summand.view(samples_per_batch, -1).sum(dim=-1).logsumexp(dim=0) - np.log(samples_per_batch).item() + instance_ppl = torch.exp(-current_avg.sum() / denom.sum()) + + weight = logsumexp(weight, summand, i, samples_per_batch) + ppl = torch.exp(-weight / denom.sum()) + + individual_estimates[i] = instance_ppl.item() + trajectory[i] = ppl.item() + + s_probs[:, i] = torch.exp(-summand.cpu() / denom.cpu()).numpy() + # summands.append(summand) + # # penalized_summands.append(penalized_summand) + # # if i == 0: + # # t = summand.unsqueeze(0) + # # p = penalized_summand.unsqueeze(0) + # # else: + # # t = torch.stack(summands, dim=0) + # # # p = torch.stack(penalized_summands, dim=0) + # t = torch.cat(summands, dim=0) + # t_sum = torch.logsumexp(t, dim=0) + # # p_sum = torch.logsumexp(p, dim=0) + # sum_logp = (t_sum - math.log((i+1)*1000)).item() + # # sum_logp_penalized = (p_sum - math.log((i+1)*1000)).item() + # ppl = math.exp(-sum_logp / 659) + # # upp = math.exp(-sum_logp_penalized / denom) + + # trajectory[i] = ppl + # # individual_estimates[i] = math.exp(-summand.item() / denom) + + # print('PPL: %f' % ppl) + # # print('UPP: %f' % upp) + + metrics = { + 'ppl': ppl, + # 'upp': upp, + 'trajectory': trajectory, + 'individual_estimates': individual_estimates, + 's_probs': s_probs + } return metrics def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]: @@ -158,10 +277,16 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]: # To avoid hairy issues with splitting, we opt to use a basic iterator so that we can # generate samples for entire sequences. 
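+    # The --batch-size and --split-size flags added above override the iterator settings
+    # stored in the training config, and truncation is turned off so that importance
+    # weights are computed over complete sequences.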
iterator_params = config.pop('iterator', 'None') + if args.batch_size is not None: + iterator_params['batch_size'] = args.batch_size + if args.split_size is not None: + iterator_params['split_size'] = args.split_size + iterator_params['truncate'] = False iterator = DataIterator.from_params(iterator_params) iterator.index_with(model.vocab) - iterator.eval() - metrics = evaluate_perplexity(model, sampler, args.num_samples, instances, iterator, args.cuda_device) + metrics = evaluate_perplexity(model, sampler, args.num_samples, instances, + iterator, args.cuda_device, args.temperature, + args.offset, args.samples_per_batch) logger.info('Finished evaluating.') logger.info('Metrics:') @@ -170,7 +295,8 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]: output_file = args.output_file if output_file: - with open(output_file, 'w') as f: - json.dump(metrics, f, indent=4) + np.save(output_file + '.trajectory.npy', metrics['trajectory']) + np.save(output_file + '.individual_estimates.npy', metrics['individual_estimates']) + np.save(output_file + '.s_probs.npy', metrics['s_probs']) return metrics diff --git a/kglm/data/dataset_readers/conll2012.py b/kglm/data/dataset_readers/conll2012.py index e1fee39..f5e58e7 100644 --- a/kglm/data/dataset_readers/conll2012.py +++ b/kglm/data/dataset_readers/conll2012.py @@ -212,9 +212,11 @@ class Conll2012JsonlReader(DatasetReader): def __init__(self, token_indexers: Dict[str, TokenIndexer] = None, replace_numbers: bool = True, + offset: int = 0, lazy: bool = False) -> None: super().__init__(lazy) self._replace_numbers = replace_numbers + self._offset = offset self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()} @overrides @@ -239,17 +241,15 @@ def text_to_instance(self, entity_ids = np.zeros(shape=(len(tokens),)) mention_lengths = np.ones(shape=(len(tokens),)) - max_i = 0 for i, cluster in enumerate(clusters.values()): for span in cluster: start, end = span - entity_types[(start + 1):(end + 1)] = 1 # TODO: Double check: +1 due to added '@@START@@' token - entity_ids[(start + 1):(end + 1)] = i + 1 - mention_lengths[(start + 1):(end + 1)] = np.arange(end - start, 0, step=-1) - max_i = max(max_i, i) + entity_types[(start + 1 - self._offset):(end + 1 - self._offset)] = 1 + entity_ids[(start + 1 - self._offset):(end + 1 - self._offset)] = i + 1 + mention_lengths[(start + 1 - self._offset):(end + 1 - self._offset)] = np.arange(end - start, 0, step=-1) + fields['entity_types'] = SequentialArrayField(entity_types, dtype=np.uint8) fields['entity_ids'] = SequentialArrayField(entity_ids, dtype=np.int64) fields['mention_lengths'] = SequentialArrayField(mention_lengths, dtype=np.int64) - fields['entity_types'] = SequentialArrayField(entity_types, dtype=np.uint8) return Instance(fields) diff --git a/kglm/data/iterators/fancy_iterator.py b/kglm/data/iterators/fancy_iterator.py index e54643b..2c6a5df 100644 --- a/kglm/data/iterators/fancy_iterator.py +++ b/kglm/data/iterators/fancy_iterator.py @@ -94,6 +94,7 @@ def __call__(self, new_fields[name] = field else: new_fields[name] = field.empty_field() + new_fields['reset'] = SequentialArrayField(np.array(1), dtype=np.uint8) # Always reset blank instances blank_instance = Instance(new_fields) for batch in self._generate_batches(queues, blank_instance): diff --git a/kglm/models/entity_disc.py b/kglm/models/entity_disc.py index da16224..692becd 100644 --- a/kglm/models/entity_disc.py +++ b/kglm/models/entity_disc.py @@ -2,7 +2,7 @@ Discriminative version of EntityNLM for importance sampling. 
""" import logging -from typing import Dict, Optional, Union +from typing import Dict, List, Optional, Union from allennlp.nn.util import get_text_field_mask from allennlp.data.vocabulary import Vocabulary @@ -74,6 +74,8 @@ def __init__(self, self._num_layers = num_layers self._max_mention_length = max_mention_length self._max_embeddings = max_embeddings + self._sos_token = self.vocab.get_token_index('@@START@@', 'tokens') + self._eos_token = self.vocab.get_token_index('@@END@@', 'tokens') # Rnn Encoders. rnns: List[torch.nn.Module] = [] @@ -167,7 +169,10 @@ def forward(self, # pylint: disable=arguments-differ def sample(self, source: Dict[str, torch.Tensor], - reset: torch.ByteTensor=None) -> Dict[str, torch.Tensor]: + reset: torch.ByteTensor = None, + temperature: float = 1.0, + offset: bool = False, + **kwargs) -> Dict[str, torch.Tensor]: """ Generates a sample from the discriminative model. @@ -209,6 +214,19 @@ def sample(self, # Embed tokens and get RNN hidden state. mask = get_text_field_mask(source) + + # If not offsetting, we need to ignore contribution of @@START@@ token annotation since it + # is never used. + + if not offset: + sos_mask = source['tokens'].ne(self._sos_token) + mask = mask.byte() & sos_mask + # If offsetting, we need to ignore contribution of @@END@@ token annotation since it is + # never used. + if offset: + eos_mask = source['tokens'].ne(self._eos_token) + mask = mask.byte() & eos_mask + embeddings = self._text_field_embedder(source) current_input = embeddings hidden_list = [] @@ -235,8 +253,6 @@ def sample(self, entity_ids = torch.zeros_like(source['tokens']) mention_lengths = torch.ones_like(source['tokens']) - # Generate outputs - prev_mention_lengths = torch.ones(mask.size(0)).to(mask.device) for timestep in range(sequence_length): current_hidden = hidden[:, timestep] @@ -244,10 +260,11 @@ def sample(self, # We only predict types / ids / lengths if the previous mention is terminated. predict_mask = prev_mention_lengths == 1 predict_mask = predict_mask & mask[:, timestep].byte() + if predict_mask.any(): # Predict entity types - entity_type_logits = self._entity_type_projection(current_hidden[predict_mask]) + entity_type_logits = self._entity_type_projection(current_hidden[predict_mask]) / temperature entity_type_logp = F.log_softmax(entity_type_logits, dim=-1) entity_type_prediction_logp, entity_type_predictions = sample_from_logp(entity_type_logp) entity_type_predictions = entity_type_predictions.byte() @@ -256,11 +273,12 @@ def sample(self, # Only predict entity and mention lengths if we predicted that there was a mention predict_em = entity_types[:, timestep] & predict_mask - if predict_em.sum() > 0: + if predict_em.any(): # Predict entity ids entity_id_prediction_outputs = self._dynamic_embeddings(hidden=current_hidden, + timestep=timestep, mask=predict_em) - entity_id_logits = entity_id_prediction_outputs['logits'] + entity_id_logits = entity_id_prediction_outputs['logits'] / temperature entity_id_mask = entity_id_prediction_outputs['logit_mask'] entity_id_probs = masked_softmax(entity_id_logits, entity_id_mask) @@ -274,7 +292,7 @@ def sample(self, # entities, but need the null embeddings here. 
predicted_entity_embeddings = self._dynamic_embeddings.embeddings[predict_em, entity_id_predictions] concatenated = torch.cat((current_hidden[predict_em], predicted_entity_embeddings), dim=-1) - mention_length_logits = self._mention_length_projection(concatenated) + mention_length_logits = self._mention_length_projection(concatenated) / temperature mention_length_logp = F.log_softmax(mention_length_logits, dim=-1) mention_length_prediction_logp, mention_length_predictions = sample_from_logp(mention_length_logp) @@ -285,6 +303,7 @@ def sample(self, entity_ids[predict_em, timestep] = entity_id_predictions logp[predict_em] += entity_id_prediction_logp + mention_lengths[predict_em, timestep] = mention_length_predictions logp[predict_em] += mention_length_prediction_logp @@ -302,11 +321,11 @@ def sample(self, # not need to add anything to logp since these 'predictions' have probability 1 under # the model. deterministic_mask = prev_mention_lengths > 1 - deterministic_mask = deterministic_mask * mask[:, timestep].byte() - if deterministic_mask.sum() > 1: + deterministic_mask = deterministic_mask & mask[:, timestep].byte() + if deterministic_mask.any(): entity_types[deterministic_mask, timestep] = entity_types[deterministic_mask, timestep - 1] entity_ids[deterministic_mask, timestep] = entity_ids[deterministic_mask, timestep - 1] - mention_lengths[deterministic_mask, timestep] = mention_lengths[deterministic_mask, timestep - 1] - 1 + mention_lengths[deterministic_mask, timestep] = prev_mention_lengths[deterministic_mask] - 1 # Update mention lengths for next timestep prev_mention_lengths = mention_lengths[:, timestep] @@ -408,8 +427,8 @@ def _forward_loop(self, # generating a mention (e.g. if the previous remaining mention length is 1). Indexing / # masking with ``predict_all`` makes it possible to do this in batch. predict_all = prev_mention_lengths == 1 - predict_all = predict_all * mask[:, timestep].byte() - if predict_all.sum() > 0: + predict_all = predict_all & mask[:, timestep].byte() + if predict_all.any(): # Equation 3 in the paper. entity_type_logits = self._entity_type_projection(current_hidden[predict_all]) @@ -421,9 +440,9 @@ def _forward_loop(self, gold_labels=current_entity_types[predict_all].long()) # Only proceed to predict entity and mention length if there is in fact an entity. - predict_em = current_entity_types * predict_all + predict_em = current_entity_types & predict_all - if predict_em.sum() > 0: + if predict_em.any(): # Equation 4 in the paper. We want new entities to correspond to a prediction of # zero, their embedding should be added after they've been predicted for the first # time. @@ -489,7 +508,7 @@ def reset_states(self, reset: torch.ByteTensor) -> None: """Resets the model's internals. 
Should be called at the start of a new batch.""" if reset.any() and (self._state is not None): # Zero out any previous elements - self._state['prev_mention_lengths'][reset].zero_() + self._state['prev_mention_lengths'][reset] = 1 # Zero out the hidden state for layer in range(self._num_layers): h, c = self._state['layer_%i' % layer] @@ -511,4 +530,3 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]: 'eid_acc': self._entity_id_accuracy.get_metric(reset), 'ml_acc': self._mention_length_accuracy.get_metric(reset) } - diff --git a/kglm/models/entity_nlm.py b/kglm/models/entity_nlm.py index 6956576..63b7ffd 100644 --- a/kglm/models/entity_nlm.py +++ b/kglm/models/entity_nlm.py @@ -245,7 +245,7 @@ def _forward_loop(self, contexts = self._dummy_context_embedding.repeat(batch_size, 1) # Embed tokens and get RNN hidden state. - mask = get_text_field_mask(tokens) + mask = get_text_field_mask(tokens).byte() embeddings = self._text_field_embedder(tokens) embeddings = self._variational_dropout(embeddings) @@ -273,7 +273,7 @@ def _forward_loop(self, entity_id_loss = 0.0 mention_length_loss = 0.0 vocab_loss = 0.0 - # logp = hidden.new_zeros(batch_size) + logp = hidden.new_zeros(batch_size) # We dynamically add entities and update their representations in sequence. The following # loop is designed to imitate as closely as possible lines 219-313 in: @@ -318,14 +318,15 @@ def _forward_loop(self, # We only predict the types / ids / lengths of the next mention if we are not currently # in the process of generating it (e.g. if the current remaining mention length is 1). # Indexing / masking with ``predict_all`` makes it possible to do this in batch. - predict_all = (current_mention_lengths == 1) * next_mask.byte() - if predict_all.sum() > 0: + predict_all = (current_mention_lengths == 1) & next_mask + if predict_all.any(): # Equation 3 in the paper. entity_type_logits = self._entity_type_projection(current_hidden[predict_all]) entity_type_logp = F.log_softmax(entity_type_logits, -1) _entity_type_loss = -entity_type_logp.gather(-1, next_entity_types[predict_all].long().unsqueeze(-1)) entity_type_loss += _entity_type_loss.sum() + logp[predict_all] += -_entity_type_loss.squeeze() # entity_type_logp = torch.zeros_like(next_entity_types, dtype=torch.float32) # entity_type_logp[predict_all] = -_entity_type_loss @@ -335,8 +336,8 @@ def _forward_loop(self, gold_labels=next_entity_types[predict_all].long()) # Only proceed to predict entity and mention length if there is in fact an entity. - predict_em = next_entity_types * predict_all - if predict_em.sum() > 0: + predict_em = next_entity_types & predict_all + if predict_em.any(): # Equation 4 in the paper. entity_id_prediction_outputs = self._dynamic_embeddings(hidden=current_hidden, timestep=timestep, @@ -344,6 +345,7 @@ def _forward_loop(self, mask=predict_em) _entity_id_loss = -entity_id_prediction_outputs['loss'] entity_id_loss += _entity_id_loss.sum() + logp[predict_em] += -_entity_id_loss.squeeze() # entity_id_logp = torch.zeros_like(next_entity_ids, dtype=torch.float32) # entity_id_logp[predict_em] = -_entity_id_loss @@ -363,6 +365,7 @@ def _forward_loop(self, # next_mention_lengths[predict_em], # reduction='none') mention_length_loss += _mention_length_loss.sum() + logp[predict_em] += -_mention_length_loss.squeeze() # mention_length_logp = torch.zeros_like(next_mention_lengths, dtype=torch.float32) # mention_length_logp[predict_em] = -_mention_length_loss @@ -374,19 +377,20 @@ def _forward_loop(self, # Always predict the next word. 
This is done using the hidden state and contextual bias. entity_embeddings = self._dynamic_embeddings.embeddings[next_entity_types, next_entity_ids[next_entity_types]] entity_embeddings = self._entity_output_projection(entity_embeddings) - context_embeddings = contexts[1 - next_entity_types] + context_embeddings = contexts[~next_entity_types] context_embeddings = self._context_output_projection(context_embeddings) # The checks in the following block of code are required to prevent adding empty # tensors to vocab_features (which causes a floating point error). vocab_features = current_hidden.clone() - if next_entity_types.sum() > 0: + if next_entity_types.any(): vocab_features[next_entity_types] = vocab_features[next_entity_types] + entity_embeddings - if (1 - next_entity_types.sum()) > 0: - vocab_features[1 - next_entity_types] = vocab_features[1 - next_entity_types] + context_embeddings - vocab_logits = self._vocab_projection(vocab_features[next_mask.byte()]) + if (~next_entity_types).any(): + vocab_features[~next_entity_types] = vocab_features[~next_entity_types] + context_embeddings + vocab_logits = self._vocab_projection(vocab_features[next_mask]) vocab_logp = F.log_softmax(vocab_logits, -1) - _vocab_loss = -vocab_logp.gather(-1, next_tokens[next_mask.byte()].unsqueeze(-1)) + _vocab_loss = -vocab_logp.gather(-1, next_tokens[next_mask].unsqueeze(-1)) + logp[next_mask] += -_vocab_loss.squeeze() # _vocab_loss = F.cross_entropy(vocab_logits, next_tokens, reduction='none') # _vocab_loss = _vocab_loss * next_mask.float() @@ -403,8 +407,11 @@ def _forward_loop(self, # Lastly update contexts contexts = current_hidden - # Normalize the losses self._perplexity(vocab_loss, mask.sum()) + + # logp = -(entity_type_loss + entity_id_loss + mention_length_loss + vocab_loss) + + # Normalize the losses entity_type_loss /= mask.sum() logger.debug('Entity type loss: %0.4f', entity_type_loss) entity_id_loss /= mask.sum() @@ -421,7 +428,7 @@ def _forward_loop(self, 'mention_length_loss': mention_length_loss, 'vocab_loss': vocab_loss, 'loss': total_loss, - 'logp': -total_loss * mask.sum(), + 'logp': logp, 'penalized_logp': -total_loss * mask.sum() } diff --git a/kglm/modules/dynamic_embeddings.py b/kglm/modules/dynamic_embeddings.py index 1b1c4d0..3578fd3 100644 --- a/kglm/modules/dynamic_embeddings.py +++ b/kglm/modules/dynamic_embeddings.py @@ -218,7 +218,7 @@ def forward(self, # pylint: disable=arguments-differ bilinear = bilinear.view(batch_size, -1) # Second half of equation 4. 
-        distance_score = torch.exp(self._distance_scalar * (timestep - self.last_seen[mask].float()))
+        distance_score = torch.exp(self._distance_scalar * (self.last_seen[mask].float() - timestep))
         logits = bilinear + distance_score
 
         # Since we pre-allocate the embedding array, logits includes scores for all of the
diff --git a/kglm/nn/util.py b/kglm/nn/util.py
index f86a85d..f6acd6d 100644
--- a/kglm/nn/util.py
+++ b/kglm/nn/util.py
@@ -58,6 +58,10 @@ def sample_from_logp(logp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     cdf = torch.cumsum(pdf, dim=-1)
     rng = torch.rand(logp.shape[:-1], device=logp.device).unsqueeze(-1)
     selected_idx = cdf.lt(rng).sum(dim=-1)
+    # Clamp any index that cumsum rounding error pushes past the final class.
+    if (selected_idx >= pdf.shape[-1]).any():
+        selected_idx[selected_idx >= pdf.shape[-1]] = pdf.shape[-1] - 1
+
     hack = torch.ones(logp.shape[:-1], device=logp.device, dtype=torch.uint8)
     selected_logp = logp[hack, selected_idx[hack]]
     return selected_logp, selected_idx
diff --git a/kglm/tests/dataset_readers/conll2012_test.py b/kglm/tests/dataset_readers/conll2012_test.py
index 1b84f26..9a8156b 100644
--- a/kglm/tests/dataset_readers/conll2012_test.py
+++ b/kglm/tests/dataset_readers/conll2012_test.py
@@ -61,8 +61,10 @@ def test_read_from_file(self, lazy):
                                    1, 1, 1, 1, 1, 1, 1])
 
 class TestConll2012JsonlReader:
-    def test_read_from_file(self):
-        reader = Conll2012JsonlReader()
+    @pytest.mark.parametrize('lazy', (True, False))
+    @pytest.mark.parametrize('offset', (0, 1))
+    def test_read_from_file(self, lazy, offset):
+        reader = Conll2012JsonlReader(lazy=lazy, offset=offset)
         fixture_path = 'kglm/tests/fixtures/conll2012.jsonl'
         instances = ensure_list(reader.read(fixture_path))
         assert len(instances) == 2
@@ -75,13 +77,13 @@ def test_read_from_file(self):
         second_instance_mention_lengths = instances[1]['mention_lengths'].array
         second_instance_entity_types = instances[1]['entity_types'].array
 
-        np.testing.assert_allclose(second_instance_entity_types[1:3],
+        np.testing.assert_allclose(second_instance_entity_types[(1 - offset):(3 - offset)],
                                    np.array([1,0], dtype=np.uint8))
-        np.testing.assert_allclose(second_instance_entity_ids[1:2],
+        np.testing.assert_allclose(second_instance_entity_ids[(1 - offset):(2 - offset)],
                                    np.array([1], dtype=np.int64))
-        np.testing.assert_allclose(second_instance_entity_ids[8:9],
+        np.testing.assert_allclose(second_instance_entity_ids[(8 - offset):(9 - offset)],
                                    np.array([1], dtype=np.int64))
-        np.testing.assert_allclose(second_instance_entity_ids[30:32],
+        np.testing.assert_allclose(second_instance_entity_ids[(30 - offset):(32 - offset)],
                                    np.array([1, 1], dtype=np.int64))
-        np.testing.assert_allclose(second_instance_mention_lengths[30:32],
+        np.testing.assert_allclose(second_instance_mention_lengths[(30 - offset):(32 - offset)],
                                    np.array([2, 1], dtype=np.int64))
\ No newline at end of file
diff --git a/kglm/tests/nn/test_util.py b/kglm/tests/nn/test_util.py
new file mode 100644
index 0000000..040f50c
--- /dev/null
+++ b/kglm/tests/nn/test_util.py
@@ -0,0 +1,17 @@
+from flaky import flaky
+import torch
+
+import kglm.nn.util as util
+
+
+@flaky(max_runs=10, min_passes=1)
+def test_sample_from_logp():
+    # (batch_size, n_classes)
+    unnormalized = torch.randn(10, 30)
+    # log_softmax gives a valid log-distribution (randn divided by its sum can be negative).
+    logp = torch.log_softmax(unnormalized, dim=-1)
+    logits_a, sample_a = util.sample_from_logp(logp)
+    logits_b, sample_b = util.sample_from_logp(logp)
+    assert not torch.allclose(logits_a, logits_b)
+    assert not torch.allclose(sample_a, sample_b)
+
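Note on the estimator implemented in kglm/commands/evaluate_perplexity.py above: each round draws proposal annotations z ~ q(z | x) from the discriminative sampler, scores them under the generative model, and accumulates log w = log p(x, z) - log q(z | x); the rounds are then averaged in log space so the running perplexity estimate stays numerically stable. The sketch below is illustrative only (the helper name update_log_mean and all numbers are made up), but it follows the same running log-mean recurrence as the patch's logsumexp helper.

import torch

def update_log_mean(prev, current, i):
    # Running log of an arithmetic mean over importance weights:
    #   log((i * exp(prev) + exp(current)) / (i + 1)), computed without leaving log space.
    if prev is None:
        return current
    a = torch.max(prev, current)
    return a + torch.log(torch.exp(prev - a) * i / (i + 1) + torch.exp(current - a) / (i + 1))

# Illustrative usage: log_w[i] is round i's total log importance weight over the corpus,
# n_tokens is the corpus token count, and perplexity is exp(-log_mean / n_tokens).
log_w = [torch.tensor(-3950.0), torch.tensor(-4010.0), torch.tensor(-3970.0)]
n_tokens = 659.0
running = None
for i, w in enumerate(log_w):
    running = update_log_mean(running, w, i)
    print('after %d sample(s): ppl = %.2f' % (i + 1, torch.exp(-running / n_tokens).item()))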