From 9151eabd0deec5316d0f56ae39d021e320a2e9bb Mon Sep 17 00:00:00 2001
From: rloganiv
Date: Tue, 15 Oct 2019 16:21:45 -0700
Subject: [PATCH] Improved sampling efficiency

Batch importance samples during perplexity evaluation (new
--samples-per-batch flag plus a running log-mean-exp over sample
batches), have the samplers return per-instance log-probabilities, and
add validation iterators and mini / no-peeking CONLL 2012 configs. Also
fixes the sign of the distance score in DynamicEmbedding and guards
against out-of-range indices in sample_from_logp.
---
 experiments/entity_disc_conll2012.jsonnet     |  21 ++-
 .../entity_disc_conll2012_mini.jsonnet        |  70 ++++++++
 .../entity_disc_conll2012_no_peeking.jsonnet  |  71 ++++++++
 ...ity_disc_conll2012_no_peeking_mini.jsonnet |  71 ++++++++
 experiments/entity_disc_test.jsonnet          |   4 +-
 experiments/entity_nlm.jsonnet                |  14 +-
 experiments/entity_nlm_conll2012.jsonnet      |  23 ++-
 experiments/entity_nlm_conll2012_mini.jsonnet |  71 ++++++++
 kglm/commands/evaluate_perplexity.py          | 152 +++++++++++++-----
 kglm/models/entity_disc.py                    |  20 ++-
 kglm/models/entity_nlm.py                     |  10 +-
 kglm/modules/dynamic_embeddings.py            |   2 +-
 kglm/nn/util.py                               |   4 +
 kglm/tests/nn/test_util.py                    |  17 ++
 14 files changed, 487 insertions(+), 63 deletions(-)
 create mode 100644 experiments/entity_disc_conll2012_mini.jsonnet
 create mode 100644 experiments/entity_disc_conll2012_no_peeking.jsonnet
 create mode 100644 experiments/entity_disc_conll2012_no_peeking_mini.jsonnet
 create mode 100644 experiments/entity_nlm_conll2012_mini.jsonnet
 create mode 100644 kglm/tests/nn/test_util.py

diff --git a/experiments/entity_disc_conll2012.jsonnet b/experiments/entity_disc_conll2012.jsonnet
index 72bb063..2cee49d 100644
--- a/experiments/entity_disc_conll2012.jsonnet
+++ b/experiments/entity_disc_conll2012.jsonnet
@@ -2,7 +2,7 @@
     "vocabulary": {
         "type": "extended",
         "extend": false,
-        "directory_path": "results/entity-nlm-conll2012-fixed/vocabulary"
+        "directory_path": "data/vocabulary"
     },
     "dataset_reader": {
         "type": "conll2012_jsonl",
@@ -36,7 +36,7 @@
     },
     "iterator": {
         "type": "fancy",
-        "batch_size": 16,
+        "batch_size": 343,
         "split_size": 30,
         "splitting_keys": [
             "source",
@@ -45,13 +45,26 @@
             "mention_lengths"
         ],
     },
+    "validation_iterator": {
+        "type": "fancy",
+        "batch_size": 343,
+        "split_size": 128,
+        "splitting_keys": [
+            "source",
+            "entity_types",
+            "entity_ids",
+            "mention_lengths"
+        ],
+        "truncate": false
+    },
     "trainer": {
         "type": "lm",
-        "num_epochs": 40,
+        "num_epochs": 400,
         "cuda_device": 0,
         "optimizer": {
             "type": "adam",
             "lr": 1e-3
-        }
+        },
+        "validation_metric": "+eid_acc"
     }
 }
diff --git a/experiments/entity_disc_conll2012_mini.jsonnet b/experiments/entity_disc_conll2012_mini.jsonnet
new file mode 100644
index 0000000..5736946
--- /dev/null
+++ b/experiments/entity_disc_conll2012_mini.jsonnet
@@ -0,0 +1,70 @@
+{
+    "vocabulary": {
+        "type": "extended",
+        "extend": false,
+        "directory_path": "data/vocabulary-mini"
+    },
+    "dataset_reader": {
+        "type": "conll2012_jsonl",
+        "token_indexers": {
+            "tokens": {
+                "type": "single_id",
+                "lowercase_tokens": true
+            }
+        }
+    },
+    "train_data_path": "data/conll-2012/processed/train-mini.jsonl",
+    "validation_data_path": "data/conll-2012/processed/dev-mini.jsonl",
+    "model": {
+        "type": "entitydisc",
+        "text_field_embedder": {
+            "token_embedders": {
+                "tokens": {
+                    "type": "embedding",
+                    "embedding_dim": 256,
+                    "trainable": true
+                },
+            },
+        },
+        "embedding_dim": 256,
+        "hidden_size": 256,
+        "num_layers": 1,
+        "max_mention_length": 50,
+        "max_embeddings": 10,
+        "dropout_rate": 0.4,
+        "variational_dropout_rate": 0.1
+    },
+    "iterator": {
+        "type": "fancy",
+        "batch_size": 16,
+        "split_size": 15,
+        "splitting_keys": [
+            "source",
+            "entity_types",
+            "entity_ids",
+            "mention_lengths"
+        ],
+    },
+    "validation_iterator": {
+        "type": "fancy",
+        "batch_size": 16,
+        "split_size": 15,
+        "splitting_keys": [
+            "source",
+            "entity_types",
+            "entity_ids",
"mention_lengths" + ], + "truncate": false + }, + "trainer": { + "type": "lm", + "num_epochs": 400, + "cuda_device": 0, + "optimizer": { + "type": "adam", + "lr": 1e-3 + }, + "validation_metric": "+eid_acc" + } +} diff --git a/experiments/entity_disc_conll2012_no_peeking.jsonnet b/experiments/entity_disc_conll2012_no_peeking.jsonnet new file mode 100644 index 0000000..5b9b5bf --- /dev/null +++ b/experiments/entity_disc_conll2012_no_peeking.jsonnet @@ -0,0 +1,71 @@ +{ + "vocabulary": { + "type": "extended", + "extend": false, + "directory_path": "data/vocabulary" + }, + "dataset_reader": { + "type": "conll2012_jsonl", + "offset": 1, + "token_indexers": { + "tokens": { + "type": "single_id", + "lowercase_tokens": true + } + } + }, + "train_data_path": "data/conll-2012/processed/train.jsonl", + "validation_data_path": "data/conll-2012/processed/dev.jsonl", + "model": { + "type": "entitydisc", + "text_field_embedder": { + "token_embedders": { + "tokens": { + "type": "embedding", + "embedding_dim": 128, + "trainable": true + }, + }, + }, + "embedding_dim": 128, + "hidden_size": 128, + "num_layers": 1, + "max_mention_length": 100, + "max_embeddings": 100, + "dropout_rate": 0.4, + "variational_dropout_rate": 0.1 + }, + "iterator": { + "type": "fancy", + "batch_size": 16, + "split_size": 15, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + }, + "validation_iterator": { + "type": "fancy", + "batch_size": 16, + "split_size": 15, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + "truncate": false + }, + "trainer": { + "type": "lm", + "num_epochs": 400, + "cuda_device": 0, + "optimizer": { + "type": "adam", + "lr": 1e-4 + }, + "validation_metric": "+eid_acc" + } +} diff --git a/experiments/entity_disc_conll2012_no_peeking_mini.jsonnet b/experiments/entity_disc_conll2012_no_peeking_mini.jsonnet new file mode 100644 index 0000000..046f380 --- /dev/null +++ b/experiments/entity_disc_conll2012_no_peeking_mini.jsonnet @@ -0,0 +1,71 @@ +{ + "vocabulary": { + "type": "extended", + "extend": false, + "directory_path": "data/vocabulary-mini" + }, + "dataset_reader": { + "type": "conll2012_jsonl", + "offset": 1, + "token_indexers": { + "tokens": { + "type": "single_id", + "lowercase_tokens": true + } + } + }, + "train_data_path": "data/conll-2012/processed/train-mini.jsonl", + "validation_data_path": "data/conll-2012/processed/dev-mini.jsonl", + "model": { + "type": "entitydisc", + "text_field_embedder": { + "token_embedders": { + "tokens": { + "type": "embedding", + "embedding_dim": 128, + "trainable": true + }, + }, + }, + "embedding_dim": 128, + "hidden_size": 128, + "num_layers": 1, + "max_mention_length": 50, + "max_embeddings": 10, + "dropout_rate": 0.4, + "variational_dropout_rate": 0.1 + }, + "iterator": { + "type": "fancy", + "batch_size": 16, + "split_size": 15, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + }, + "validation_iterator": { + "type": "fancy", + "batch_size": 16, + "split_size": 15, + "splitting_keys": [ + "source", + "entity_types", + "entity_ids", + "mention_lengths" + ], + "truncate": false + }, + "trainer": { + "type": "lm", + "num_epochs": 400, + "cuda_device": 0, + "optimizer": { + "type": "adam", + "lr": 1e-4 + }, + "validation_metric": "+eid_acc" + } +} diff --git a/experiments/entity_disc_test.jsonnet b/experiments/entity_disc_test.jsonnet index 4c9f6e8..ed932e0 100644 --- a/experiments/entity_disc_test.jsonnet +++ 
+++ b/experiments/entity_disc_test.jsonnet
@@ -2,7 +2,7 @@
     "vocabulary": {
         "type": "extended",
         "extend": false,
-        "directory_path": "./results/entity-nlm-wt2.fixed-vocab.dropout.2/vocabulary"
+        "directory_path": "data/vocabulary"
     },
     "dataset_reader": {
         "type": "enhanced-wikitext",
@@ -59,4 +59,4 @@
             "lr": 1e-3
         }
     }
-}
\ No newline at end of file
+}
diff --git a/experiments/entity_nlm.jsonnet b/experiments/entity_nlm.jsonnet
index f2a3eca..9bfc403 100644
--- a/experiments/entity_nlm.jsonnet
+++ b/experiments/entity_nlm.jsonnet
@@ -4,7 +4,7 @@
     },
     "iterator": {
         "type": "fancy",
-        "batch_size": 60,
+        "batch_size": 30,
         "split_size": 70,
         "splitting_keys": [
             "source",
@@ -13,6 +13,18 @@
             "mention_lengths"
         ]
     },
+    "validation_iterator": {
+        "type": "fancy",
+        "batch_size": 30,
+        "split_size": 70,
+        "splitting_keys": [
+            "source",
+            "entity_types",
+            "entity_ids",
+            "mention_lengths"
+        ],
+        "truncate": false
+    },
     "model": {
         "type": "entitynlm",
         "dropout_rate": 0.5,
diff --git a/experiments/entity_nlm_conll2012.jsonnet b/experiments/entity_nlm_conll2012.jsonnet
index 5437e38..609151b 100644
--- a/experiments/entity_nlm_conll2012.jsonnet
+++ b/experiments/entity_nlm_conll2012.jsonnet
@@ -1,11 +1,8 @@
 {
     "vocabulary": {
         "type": "extended",
-        "max_vocab_size": {
-            // This does not count the @@UNKNOWN@@ token, which
-            // ends up being our 10,000th token.
-            "tokens": 9999
-        }
+        "extend": false,
+        "directory_path": "data/vocabulary"
     },
     "dataset_reader": {
         "type": "conll2012_jsonl",
@@ -41,14 +38,26 @@
     },
     "iterator": {
         "type": "fancy",
-        "batch_size": 16,
-        "split_size": 30,
+        "batch_size": 512,
+        "split_size": 15,
+        "splitting_keys": [
+            "source",
+            "entity_types",
+            "entity_ids",
+            "mention_lengths"
+        ],
+    },
+    "validation_iterator": {
+        "type": "fancy",
+        "batch_size": 512,
+        "split_size": 15,
         "splitting_keys": [
             "source",
             "entity_types",
             "entity_ids",
             "mention_lengths"
         ],
+        "truncate": false
     },
     "trainer": {
         "type": "lm",
diff --git a/experiments/entity_nlm_conll2012_mini.jsonnet b/experiments/entity_nlm_conll2012_mini.jsonnet
new file mode 100644
index 0000000..e35a67b
--- /dev/null
+++ b/experiments/entity_nlm_conll2012_mini.jsonnet
@@ -0,0 +1,71 @@
+{
+    "vocabulary": {
+        "type": "extended",
+        "extend": false,
+        "directory_path": "data/vocabulary-mini"
+    },
+    "dataset_reader": {
+        "type": "conll2012_jsonl",
+        "token_indexers": {
+            "tokens": {
+                "type": "single_id",
+                "lowercase_tokens": true
+            }
+        }
+    },
+    "train_data_path": "data/conll-2012/processed/train-mini.jsonl",
+    "validation_data_path": "data/conll-2012/processed/dev-mini.jsonl",
+    "datasets_for_vocab_creation": ["train"],
+    "model": {
+        "type": "entitynlm",
+        "text_field_embedder": {
+            "token_embedders": {
+                "tokens": {
+                    "type": "embedding",
+                    "embedding_dim": 256,
+                    "trainable": true
+                },
+            },
+        },
+        "embedding_dim": 256,
+        "hidden_size": 256,
+        "num_layers": 1,
+        "max_mention_length": 50,
+        "max_embeddings": 10,
+        "tie_weights": true,
+        "dropout_rate": 0.4,
+        "variational_dropout_rate": 0.1
+    },
+    "iterator": {
+        "type": "fancy",
+        "batch_size": 16,
+        "split_size": 15,
+        "splitting_keys": [
+            "source",
+            "entity_types",
+            "entity_ids",
+            "mention_lengths"
+        ],
+    },
+    "validation_iterator": {
+        "type": "fancy",
+        "batch_size": 16,
+        "split_size": 15,
+        "splitting_keys": [
+            "source",
+            "entity_types",
+            "entity_ids",
+            "mention_lengths"
+        ],
+        "truncate": false
+    },
+    "trainer": {
+        "type": "lm",
+        "num_epochs": 400,
+        "cuda_device": 0,
+        "optimizer": {
+            "type": "adam",
+            "lr": 1e-3
+        }
+    }
+}
diff --git a/kglm/commands/evaluate_perplexity.py b/kglm/commands/evaluate_perplexity.py
index 079fb56..21dfeeb 100644
--- a/kglm/commands/evaluate_perplexity.py
+++ b/kglm/commands/evaluate_perplexity.py
@@ -63,7 +63,12 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argpar
     subparser.add_argument('--num-samples',
                            type=int,
-                           default=100,
-                           help='Number of importance samples to draw.')
+                           default=10000,
+                           help='Total number of importance samples to draw.')
+
+    subparser.add_argument('--samples-per-batch',
+                           type=int,
+                           default=1,
+                           help='Number of importance samples drawn per batch.')
 
     subparser.add_argument('--temperature',
@@ -98,6 +103,28 @@ def _offset(sample, held_over_data):
     return new_sample, new_held_over_data
 
 
+def tile(t, amount):
+    if isinstance(t, torch.Tensor):
+        args = [1 for _ in t.shape]
+        args[0] = amount
+        return t.repeat(*args)
+    elif isinstance(t, dict):
+        return {k: tile(v, amount) for k, v in t.items()}
+
+
+def logsumexp(prev: torch.FloatTensor,
+              current: torch.FloatTensor,
+              i: int,
+              samples_per_batch: int):
+    # Log-mean-exp over this batch's samples, then fold it into `prev`, the running mean over the i batches seen so far.
+    current_avg = current.view(samples_per_batch, -1).sum(dim=-1).logsumexp(dim=0) - np.log(samples_per_batch).item()
+    if prev is None:
+        return current_avg
+    a = torch.max(prev, current_avg)
+    sumexp = torch.exp(prev - a) * i / (i + 1) + torch.exp(current_avg - a) / (i + 1)
+    return a + torch.log(sumexp)
+
+
 def evaluate_perplexity(model: Model,
                         sampler: Model,
                         num_samples: int,
@@ -105,40 +132,57 @@ def evaluate_perplexity(model: Model,
                         data_iterator: DataIterator,
                         cuda_device: int,
                         temperature: float = 1.0,
-                        offset: bool = False) -> Dict[str, Any]:
+                        offset: bool = False,
+                        samples_per_batch: int = 1) -> Dict[str, Any]:
     check_for_gpu(cuda_device)
     logger.info('Iterating over dataset')
 
-    summands = []
-    penalized_summands = []
-    trajectory = np.zeros(num_samples)
-    individual_estimates = np.zeros(num_samples)
+    # summands = []
+    # penalized_summands = []
+    trajectory = np.zeros(num_samples // samples_per_batch)
+    individual_estimates = np.zeros(num_samples // samples_per_batch)
+    s_probs = np.zeros((348, num_samples // samples_per_batch))  # NOTE: 348 is the hard-coded number of evaluation instances
+
 
-    for i in range(num_samples):
+    weight = None
+
+    for i in range(num_samples // samples_per_batch):
         iterator = data_iterator(instances, num_epochs=1, shuffle=False)
         generator_tqdm = Tqdm.tqdm(iterator, total=0)
         model.eval()
         sampler.eval()
+        sampler._state = None
 
-        summand = torch.tensor(0.0)
-        penalized_summand = torch.tensor(0.0)
-        denom = 0
+        summand = None
+        denom = None
+        # summand = torch.tensor(0.0)
+        # penalized_summand = torch.tensor(0.0)
         held_over_data = None
 
         for batch, _ in generator_tqdm:
 
+            # We need sequence length to help compute perplexity
+            n_tokens = util.get_text_field_mask(batch['source']).float().sum(dim=-1)
+            if denom is None:
+                denom = n_tokens
+            else:
+                denom += n_tokens
+
+            summand = util.move_to_device(summand, cuda_device)  # None on the first batch; passes through unchanged
            batch = util.move_to_device(batch, cuda_device)
 
-            # We need sequence length to help compute perplexity
-            n_tokens = util.get_text_field_mask(batch['source']).float().sum().item()
-            denom += n_tokens
+            # Tile if that's what we're doing
+            if samples_per_batch > 1:
+                batch = tile(batch, samples_per_batch)
 
             # Draw a sample
             with torch.no_grad():
-                sampler_output = sampler.sample(**batch, temperature=temperature)
+                sampler_output = sampler.sample(**batch,
+                                                temperature=temperature,
+                                                offset=offset)
 
             sample_logp = sampler_output['logp']
             sample = sampler_output['sample']
@@ -150,35 +194,56 @@ def evaluate_perplexity(model: Model,
                 model_output = model(**sample)
 
             model_logp = model_output['logp']
-            model_penalized_logp = model_output['penalized_logp']
-            summand += (model_logp - sample_logp)
-            penalized_summand += (model_penalized_logp - sample_logp)
+            if summand is None:
+                summand = (model_logp - sample_logp)
+            else:
+                summand += (model_logp - sample_logp)
+
+            # model_penalized_logp = model_output['penalized_logp']
+            # penalized_summand += (model_penalized_logp - sample_logp)
 
             # generator_tqdm.set_description('Instantaneous PPL: %0.4f' % torch.exp((sample_logp - model_logp) / n_tokens).item())
 
-        summands.append(summand)
-        penalized_summands.append(penalized_summand)
-        if i == 0:
-            t = summand.unsqueeze(0)
-            p = penalized_summand.unsqueeze(0)
-        else:
-            t = torch.stack(summands, dim=0)
-            p = torch.stack(penalized_summands, dim=0)
-        t_sum = torch.logsumexp(t, dim=0)
-        p_sum = torch.logsumexp(p, dim=0)
-        sum_logp = (t_sum - math.log(i+1)).item()
-        sum_logp_penalized = (p_sum - math.log(i+1)).item()
-        ppl = math.exp(-sum_logp / denom)
-        upp = math.exp(-sum_logp_penalized / denom)
-
-        trajectory[i] = ppl
-        individual_estimates[i] = math.exp(-summand.item() / denom)
-
-        print('PPL: %f' % ppl)
-        print('UPP: %f' % upp)
-
-    metrics = {'ppl': ppl, 'upp': upp, 'trajectory': trajectory,
-               'individual_estimates': individual_estimates}
+
+        current_avg = summand.view(samples_per_batch, -1).sum(dim=-1).logsumexp(dim=0) - np.log(samples_per_batch).item()
+        instance_ppl = torch.exp(-current_avg.sum() / denom.sum())
+
+        weight = logsumexp(weight, summand, i, samples_per_batch)
+        ppl = torch.exp(-weight / denom.sum())
+
+        individual_estimates[i] = instance_ppl.item()
+        trajectory[i] = ppl.item()
+
+        s_probs[:, i] = torch.exp(-summand.cpu() / denom.cpu()).numpy()
+        # summands.append(summand)
+        # # penalized_summands.append(penalized_summand)
+        # # if i == 0:
+        # #     t = summand.unsqueeze(0)
+        # #     p = penalized_summand.unsqueeze(0)
+        # # else:
+        # #     t = torch.stack(summands, dim=0)
+        # # #   p = torch.stack(penalized_summands, dim=0)
+        # t = torch.cat(summands, dim=0)
+        # t_sum = torch.logsumexp(t, dim=0)
+        # # p_sum = torch.logsumexp(p, dim=0)
+        # sum_logp = (t_sum - math.log((i+1)*1000)).item()
+        # # sum_logp_penalized = (p_sum - math.log((i+1)*1000)).item()
+        # ppl = math.exp(-sum_logp / 659)
+        # # upp = math.exp(-sum_logp_penalized / denom)
+
+        # trajectory[i] = ppl
+        # # individual_estimates[i] = math.exp(-summand.item() / denom)
+
+        # print('PPL: %f' % ppl)
+        # # print('UPP: %f' % upp)
+
+    metrics = {
+        'ppl': ppl,
+        # 'upp': upp,
+        'trajectory': trajectory,
+        'individual_estimates': individual_estimates,
+        's_probs': s_probs
+    }
 
     return metrics
 
 def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
@@ -221,7 +286,7 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
     iterator.index_with(model.vocab)
 
     metrics = evaluate_perplexity(model, sampler, args.num_samples, instances,
                                   iterator, args.cuda_device, args.temperature,
-                                  args.offset)
+                                  args.offset, args.samples_per_batch)
 
     logger.info('Finished evaluating.')
     logger.info('Metrics:')
@@ -230,7 +295,8 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
 
     output_file = args.output_file
     if output_file:
-        np.save(output_file + 'trajectory.npy', metrics['trajectory'])
-        np.save(output_file + 'individual_estimates.npy', metrics['individual_estimates'])
+        np.save(output_file + '.trajectory.npy', metrics['trajectory'])
+        np.save(output_file + '.individual_estimates.npy', metrics['individual_estimates'])
+        np.save(output_file + '.s_probs.npy', metrics['s_probs'])
 
     return metrics
diff --git a/kglm/models/entity_disc.py b/kglm/models/entity_disc.py
index 7faa616..692becd 100644
--- a/kglm/models/entity_disc.py
+++ b/kglm/models/entity_disc.py
@@ -74,6 +74,8 @@ def __init__(self,
         self._num_layers = num_layers
         self._max_mention_length = max_mention_length
         self._max_embeddings = max_embeddings
+        self._sos_token = self.vocab.get_token_index('@@START@@', 'tokens')
+        self._eos_token = self.vocab.get_token_index('@@END@@', 'tokens')
 
         # Rnn Encoders.
         rnns: List[torch.nn.Module] = []
@@ -169,6 +171,7 @@ def sample(self,
                source: Dict[str, torch.Tensor],
                reset: torch.ByteTensor = None,
                temperature: float = 1.0,
+               offset: bool = False,
                **kwargs) -> Dict[str, torch.Tensor]:
         """
         Generates a sample from the discriminative model.
@@ -211,6 +214,19 @@ def sample(self,
 
         # Embed tokens and get RNN hidden state.
         mask = get_text_field_mask(source)
+
+        # If not offsetting, ignore the contribution of the @@START@@ token's
+        # annotation, since it is never used.
+
+        if not offset:
+            sos_mask = source['tokens'].ne(self._sos_token)
+            mask = mask.byte() & sos_mask
+        # If offsetting, ignore the contribution of the @@END@@ token's
+        # annotation instead, since it is never used.
+        if offset:
+            eos_mask = source['tokens'].ne(self._eos_token)
+            mask = mask.byte() & eos_mask
+
         embeddings = self._text_field_embedder(source)
         current_input = embeddings
         hidden_list = []
@@ -318,7 +334,7 @@ def sample(self,
             self._state['prev_mention_lengths'] = prev_mention_lengths.detach()
 
         return {
-            'logp': logp.sum(),
+            'logp': logp,
             'sample': {
                 'source': source,
                 'reset': reset,
@@ -492,7 +508,7 @@ def reset_states(self, reset: torch.ByteTensor) -> None:
         """Resets the model's internals. Should be called at the start of a new batch."""
         if reset.any() and (self._state is not None):
-            # Zero out any previous elements
-            self._state['prev_mention_lengths'][reset].zero_()
+            # Reset previous mention lengths to their default value of one
+            self._state['prev_mention_lengths'][reset] = 1
             # Zero out the hidden state
             for layer in range(self._num_layers):
                 h, c = self._state['layer_%i' % layer]
diff --git a/kglm/models/entity_nlm.py b/kglm/models/entity_nlm.py
index d091ae9..63b7ffd 100644
--- a/kglm/models/entity_nlm.py
+++ b/kglm/models/entity_nlm.py
@@ -273,7 +273,7 @@ def _forward_loop(self,
         entity_id_loss = 0.0
         mention_length_loss = 0.0
         vocab_loss = 0.0
-        # logp = hidden.new_zeros(batch_size)
+        logp = hidden.new_zeros(batch_size)
 
         # We dynamically add entities and update their representations in sequence. The following
         # loop is designed to imitate as closely as possible lines 219-313 in:
@@ -326,6 +326,7 @@ def _forward_loop(self,
                 entity_type_logp = F.log_softmax(entity_type_logits, -1)
                 _entity_type_loss = -entity_type_logp.gather(-1, next_entity_types[predict_all].long().unsqueeze(-1))
                 entity_type_loss += _entity_type_loss.sum()
+                logp[predict_all] += -_entity_type_loss.squeeze()
 
                 # entity_type_logp = torch.zeros_like(next_entity_types, dtype=torch.float32)
                 # entity_type_logp[predict_all] = -_entity_type_loss
@@ -344,6 +345,7 @@ def _forward_loop(self,
                                                                        mask=predict_em)
                 _entity_id_loss = -entity_id_prediction_outputs['loss']
                 entity_id_loss += _entity_id_loss.sum()
+                logp[predict_em] += -_entity_id_loss.squeeze()
 
                 # entity_id_logp = torch.zeros_like(next_entity_ids, dtype=torch.float32)
                 # entity_id_logp[predict_em] = -_entity_id_loss
@@ -363,6 +365,7 @@ def _forward_loop(self,
                 #                                        next_mention_lengths[predict_em],
                 #                                        reduction='none')
                 mention_length_loss += _mention_length_loss.sum()
+                logp[predict_em] += -_mention_length_loss.squeeze()
 
                 # mention_length_logp = torch.zeros_like(next_mention_lengths, dtype=torch.float32)
                 # mention_length_logp[predict_em] = -_mention_length_loss
@@ -374,7 +377,7 @@ def _forward_loop(self,
             # Always predict the next word. This is done using the hidden state and contextual bias.
             entity_embeddings = self._dynamic_embeddings.embeddings[next_entity_types, next_entity_ids[next_entity_types]]
             entity_embeddings = self._entity_output_projection(entity_embeddings)
-            context_embeddings = contexts[1 - next_entity_types]
+            context_embeddings = contexts[~next_entity_types]
             context_embeddings = self._context_output_projection(context_embeddings)
 
             # The checks in the following block of code are required to prevent adding empty
@@ -387,6 +390,7 @@ def _forward_loop(self,
                 vocab_logits = self._vocab_projection(vocab_features[next_mask])
                 vocab_logp = F.log_softmax(vocab_logits, -1)
                 _vocab_loss = -vocab_logp.gather(-1, next_tokens[next_mask].unsqueeze(-1))
+                logp[next_mask] += -_vocab_loss.squeeze()
 
                 # _vocab_loss = F.cross_entropy(vocab_logits, next_tokens, reduction='none')
                 # _vocab_loss = _vocab_loss * next_mask.float()
@@ -405,7 +409,7 @@ def _forward_loop(self,
 
         self._perplexity(vocab_loss, mask.sum())
 
-        logp = -(entity_type_loss + entity_id_loss + mention_length_loss + vocab_loss)
+        # logp = -(entity_type_loss + entity_id_loss + mention_length_loss + vocab_loss)
 
         # Normalize the losses
         entity_type_loss /= mask.sum()
diff --git a/kglm/modules/dynamic_embeddings.py b/kglm/modules/dynamic_embeddings.py
index 1b1c4d0..3578fd3 100644
--- a/kglm/modules/dynamic_embeddings.py
+++ b/kglm/modules/dynamic_embeddings.py
@@ -218,7 +218,7 @@ def forward(self,  # pylint: disable=arguments-differ
         bilinear = bilinear.view(batch_size, -1)
 
         # Second half of equation 4.
-        distance_score = torch.exp(self._distance_scalar * (timestep - self.last_seen[mask].float()))
+        distance_score = torch.exp(self._distance_scalar * (self.last_seen[mask].float() - timestep))
         logits = bilinear + distance_score
 
         # Since we pre-allocate the embedding array, logits includes scores for all of the
diff --git a/kglm/nn/util.py b/kglm/nn/util.py
index f86a85d..f6acd6d 100644
--- a/kglm/nn/util.py
+++ b/kglm/nn/util.py
@@ -58,6 +58,10 @@ def sample_from_logp(logp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     cdf = torch.cumsum(pdf, dim=-1)
     rng = torch.rand(logp.shape[:-1], device=logp.device).unsqueeze(-1)
     selected_idx = cdf.lt(rng).sum(dim=-1)
+    # Float rounding can leave the final cdf entry below 1.0, so `rng` may exceed every entry; clamp the resulting out-of-range index to the last class.
+    if (selected_idx >= pdf.shape[-1]).any():
+        selected_idx[selected_idx >= pdf.shape[-1]] = pdf.shape[-1] - 1
+
     hack = torch.ones(logp.shape[:-1], device=logp.device, dtype=torch.uint8)
     selected_logp = logp[hack, selected_idx[hack]]
     return selected_logp, selected_idx
diff --git a/kglm/tests/nn/test_util.py b/kglm/tests/nn/test_util.py
new file mode 100644
index 0000000..040f50c
--- /dev/null
+++ b/kglm/tests/nn/test_util.py
@@ -0,0 +1,17 @@
+from flaky import flaky
+import torch
+
+import kglm.nn.util as util
+
+
+@flaky(max_runs=10, min_passes=1)
+def test_sample_from_logp():
+    # (batch_size, n_classes)
+    unnormalized = torch.randn(10, 30)
+    # Use log-softmax so `logp` is a valid log-distribution (raw randn values divided by their sum can be negative, making log() produce NaNs)
+    logp = torch.log_softmax(unnormalized, dim=-1)
+    logits_a, sample_a = util.sample_from_logp(logp)
+    logits_b, sample_b = util.sample_from_logp(logp)
+    assert not torch.allclose(logits_a, logits_b)
+    assert not torch.allclose(sample_a, sample_b)
+
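
Note on the running estimate in kglm/commands/evaluate_perplexity.py: the patch folds each batch of importance samples into a streaming log-mean-exp, so the marginal log-likelihood estimate log p(x) ~= log (1/n) sum_k exp(log p(x, z_k) - log q(z_k | x)) is maintained without holding every sample's weight in memory. Below is a minimal standalone check of the recurrence (the harness and names are illustrative, not part of the patch):

import torch

def running_log_mean(prev, current_avg, i):
    # Same update rule as logsumexp() in evaluate_perplexity.py: `prev` is the
    # log of the mean of the first i estimates, `current_avg` is estimate i.
    if prev is None:
        return current_avg
    a = torch.max(prev, current_avg)
    sumexp = torch.exp(prev - a) * i / (i + 1) + torch.exp(current_avg - a) / (i + 1)
    return a + torch.log(sumexp)

log_weights = torch.randn(50)  # stand-in per-batch log importance weights
running = None
for i, c in enumerate(log_weights):
    running = running_log_mean(running, c, i)

# The streamed value should match a direct log-mean-exp over all estimates.
direct = torch.logsumexp(log_weights, dim=0) - torch.log(torch.tensor(50.0))
assert torch.allclose(running, direct, atol=1e-5)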
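
The new bounds check in kglm/nn/util.sample_from_logp guards a float32 edge case: the cumulative sum of a normalized pdf can top out slightly below 1.0, so a uniform draw near 1 can make cdf.lt(rng).sum(-1) equal n_classes, one past the last valid index. A standalone sketch of the failure mode and the guard (illustrative only; the repo keeps an explicit masked assignment rather than clamp):

import torch

torch.manual_seed(0)
logp = torch.log_softmax(torch.randn(4, 1000), dim=-1)
cdf = torch.cumsum(torch.exp(logp), dim=-1)

# In float32 the final cdf entry is rarely exactly 1.0.
print(cdf[:, -1])

# A draw landing above the final cdf entry would index one past the end...
rng = torch.full((4, 1), 0.99999994)  # a float32 value just below 1.0
idx = cdf.lt(rng).sum(dim=-1)

# ...so clamp such draws back to the last class, as the patch does.
idx = torch.clamp(idx, max=cdf.shape[-1] - 1)
assert (idx < cdf.shape[-1]).all()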