From ef6b352c9a54ba78a7725dac534903875688c126 Mon Sep 17 00:00:00 2001
From: rloganiv
Date: Sun, 12 Jul 2020 19:48:31 -0700
Subject: [PATCH] ACL 2020 paper updates

---
 Dockerfile                           |  5 ++--
 kglm/commands/beamsum.py             |  9 ++++--
 kglm/commands/evaluate_perplexity.py | 42 +++++++++++++++-------------
 kglm/data/__init__.py                |  1 +
 kglm/models/__init__.py              |  2 ++
 kglm/models/entity_disc.py           |  3 +-
 kglm/models/kglm_disc.py             | 17 ++++++-----
 7 files changed, 46 insertions(+), 33 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index 16239fd..dbc7ddc 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,5 @@
-FROM python:3.6.8-jessie
-
+# FROM python:3.6.8-jessie
+FROM pytorch/pytorch:1.1.0-cuda10.0-cudnn7.5-runtime
 ENV LC_ALL=C.UTF-8
 ENV LANG=C.UTF-8
 
@@ -18,6 +18,7 @@ WORKDIR /workspace
 RUN chmod -R a+w /workspace
 
 COPY requirements.txt .
+RUN pip install --upgrade pip
 RUN pip install -r requirements.txt
 
 COPY .pylintrc .pylintrc
diff --git a/kglm/commands/beamsum.py b/kglm/commands/beamsum.py
index 7d5b38f..9b0bbeb 100644
--- a/kglm/commands/beamsum.py
+++ b/kglm/commands/beamsum.py
@@ -112,9 +112,12 @@ def evaluate_perplexity(model: Model,
 
         # Draw a sample
        with torch.no_grad():
-            sample = sampler.beam_search(source=batch['source'],
-                                         reset=batch['reset'],
-                                         k=beam_width)
+            # sample = sampler.beam_search(source=batch['source'],
+            #                              reset=batch['reset'],
+            #                              target=batch['target'],
+            #                              metadata=batch['metadata'],
+            #                              k=beam_width)
+            sample = sampler.beam_search(**batch, k=beam_width)
 
         # Evaluate on sample
         with torch.no_grad():
diff --git a/kglm/commands/evaluate_perplexity.py b/kglm/commands/evaluate_perplexity.py
index 6d0fe25..2300776 100644
--- a/kglm/commands/evaluate_perplexity.py
+++ b/kglm/commands/evaluate_perplexity.py
@@ -118,13 +118,15 @@ def _chunk(x, start, stop):
         if isinstance(x, torch.Tensor):
             return x[:, start:stop].contiguous()
 
+    chunks = []
     for i in range(num_splits):
         chunk = _chunk(batch, i * split_size, (i + 1) * split_size)
 
         if i > 0:
             chunk['reset'] = torch.zeros_like(chunk['reset'])
 
-        yield chunk
+        chunks.append(chunk)
+    return chunks
 
 
 def tile(t, amount):
@@ -138,17 +140,17 @@ def tile(t, amount):
     return [x for x in t for _ in range(amount)]
 
 
-def logsumexp(prev: torch.FloatTensor,
-              current: torch.FloatTensor,
-              i: int,
-              samples_per_batch: int):
-    # NOTE: n is number of samples
-    current_avg = current.view(samples_per_batch, -1).sum(dim=-1).logsumexp(dim=0) - np.log(samples_per_batch).item()
-    if prev is None:
-        return current_avg
-    a = torch.max(prev, current_avg)
-    sumexp = torch.exp(prev - a) * i / (i + 1) + torch.exp(current_avg - a) / (i + 1)
-    return a + torch.log(sumexp)
+# def logsumexp(prev: torch.FloatTensor,
+#               current: torch.FloatTensor,
+#               i: int,
+#               samples_per_batch: int):
+#     # NOTE: n is number of samples
+#     current_avg = current.view(samples_per_batch, -1).sum(dim=-1).logsumexp(dim=0) - np.log(samples_per_batch).item()
+#     if prev is None:
+#         return current_avg
+#     a = torch.max(prev, current_avg)
+#     sumexp = torch.exp(prev - a) * i / (i + 1) + torch.exp(current_avg - a) / (i + 1)
+#     return a + torch.log(sumexp)
 
 
 def evaluate_perplexity(model: Model,
@@ -204,7 +206,7 @@ def evaluate_perplexity(model: Model,
             for j, chunk in enumerate(split(batch, split_size)):
                 generator_tqdm.set_description(f"i={i} j={j}")
 
-                chunk_tokens = util.get_text_field_mask(batch['source']).int().sum().item()
+                chunk_tokens = util.get_text_field_mask(batch['source']).int().sum()
                 if chunk_tokens == 0:
                     logger.debug('Zero chunk, skipping')
                     continue
@@ -230,11 +232,11 @@ def evaluate_perplexity(model: Model,
                     weights = split_weights
                 else:
                     weights += split_weights
-                logger.debug(torch.exp(-split_weights/split_size))
+                # logger.debug(torch.exp(-split_weights/split_size))
 
-            epoch_weights.append(weights.cpu())
-            epoch_fp.append(model_logp.view(batch_size, samples_per_batch).cpu())
-            epoch_q.append(sample_logp.view(batch_size, samples_per_batch).cpu())
+            epoch_weights.append(weights) #.cpu())
+            epoch_fp.append(model_logp.view(batch_size, samples_per_batch))# .cpu())
+            epoch_q.append(sample_logp.view(batch_size, samples_per_batch))# .cpu())
 
             # Combine all the epoch weights
             combined_weights = torch.cat(epoch_weights, dim=1)
@@ -251,9 +253,9 @@ def evaluate_perplexity(model: Model,
             logger.info(f'PPL: {torch.exp(-summand / denom)}')
 
     # Create array of all the weights
-    all_weights_array = torch.cat(all_weights, dim=0).numpy()
-    fp_array = torch.cat(fp, dim=0).numpy()
-    q_array = torch.cat(q, dim=0).numpy()
+    all_weights_array = torch.cat(all_weights, dim=0).cpu().numpy()
+    fp_array = torch.cat(fp, dim=0).cpu().numpy()
+    q_array = torch.cat(q, dim=0).cpu().numpy()
 
     # Compute perplexity
     ppl = torch.exp(-summand / denom)
diff --git a/kglm/data/__init__.py b/kglm/data/__init__.py
index 3a4cecd..b17416b 100644
--- a/kglm/data/__init__.py
+++ b/kglm/data/__init__.py
@@ -2,3 +2,4 @@
 from .fields import SequentialArrayField
 from .iterators import SplitIterator
 from .extended_vocabulary import ExtendedVocabulary
+from .dataset_readers import *
diff --git a/kglm/models/__init__.py b/kglm/models/__init__.py
index e69de29..e6ff16c 100644
--- a/kglm/models/__init__.py
+++ b/kglm/models/__init__.py
@@ -0,0 +1,2 @@
+from .entity_disc import EntityNLMDiscriminator
+from .entity_nlm import EntityNLM
diff --git a/kglm/models/entity_disc.py b/kglm/models/entity_disc.py
index 12dd790..3893119 100644
--- a/kglm/models/entity_disc.py
+++ b/kglm/models/entity_disc.py
@@ -556,7 +556,8 @@ def _trace_backpointers(source: Dict[str, torch.Tensor],
     def beam_search(self,
                     source: Dict[str, torch.Tensor],
                     reset: torch.ByteTensor,
-                    k: int) -> Tuple[torch.Tensor, torch.Tensor]:
+                    k: int,
+                    **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Obtain the top-k (approximately) most likely predictions from the model using beam
         search. Unlike typical beam search all of the beam states are returned instead of just
diff --git a/kglm/models/kglm_disc.py b/kglm/models/kglm_disc.py
index ef81587..da9c0f8 100644
--- a/kglm/models/kglm_disc.py
+++ b/kglm/models/kglm_disc.py
@@ -755,8 +755,9 @@ def _next_related_entity_logp(self, next_encoded_head, next_encoded_relation, be
         self._recent_entities.load_beam_state(beam_state.recent_entities)
         for i, candidate_ids in enumerate(self._recent_entities._remaining):
             # Cast candidate ids to a tensor, lookup embeddings, and compute score.
-            candidate_ids = torch.LongTensor(list(candidate_ids.keys()),
-                                             device=next_encoded_head.device)
+            candidate_ids = torch.tensor(list(candidate_ids.keys()),
+                                         dtype=torch.int64,
+                                         device=next_encoded_head.device)
             candidate_embeddings = self._entity_embedder(candidate_ids)
             candidate_logits = torch.mv(candidate_embeddings, next_encoded_head[i])
             candidate_logp = F.log_softmax(candidate_logits)
@@ -767,10 +768,10 @@ def _next_related_entity_logp(self, next_encoded_head, next_encoded_relation, be
 
                 # Stop early if node is isolated
                 if not s:
-                    logp_arr[i, j] = torch.FloatTensor([], device=next_encoded_head.device)
-                    parent_ids_arr[i, j] = torch.LongTensor([], device=next_encoded_head.device)
-                    relations_arr[i, j] = torch.LongTensor([], device=next_encoded_head.device)
-                    raw_entity_ids_arr[i, j] = torch.LongTensor([], device=next_encoded_head.device)
+                    logp_arr[i, j] = torch.tensor([], dtype=torch.float32, device=next_encoded_head.device)
+                    parent_ids_arr[i, j] = torch.tensor([], dtype=torch.int64, device=next_encoded_head.device)
+                    relations_arr[i, j] = torch.tensor([], dtype=torch.int64, device=next_encoded_head.device)
+                    raw_entity_ids_arr[i, j] = torch.tensor([], dtype=torch.int64, device=next_encoded_head.device)
                     continue
 
                 # Otherwise compute relation probabilities for each parent and combine
@@ -1020,7 +1021,8 @@ def beam_search(self,
                     target: Dict[str, torch.Tensor],
                     reset: torch.ByteTensor,
                     metadata: Dict[str, Any],
-                    k: int) -> Tuple[torch.Tensor, torch.Tensor]:
+                    k: int,
+                    **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Obtain the top-k (approximately) most likely predictions from the model using beam
         search. Unlike typical beam search all of the beam states are returned instead of just
@@ -1084,6 +1086,7 @@ def beam_search(self,
         output = None
 
         for timestep in range(sequence_length):
+            logger.debug(timestep)
             # Get log probabilities of all next states
             next_mention_type_logp = self._next_mention_type_logp(mention_type_logits[:, timestep],
                                                                   beam_states)
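
The `sampler.beam_search(**batch, k=beam_width)` call in beamsum.py splats every field the data iterator produces, which is why both `beam_search` signatures in this patch gain a trailing `**kwargs`: each model consumes the fields it needs (`source`, `reset`, and for the KGLM discriminator also `target` and `metadata`) and silently discards the rest. A toy sketch of the pattern, with illustrative names rather than the repo's actual API:

    from typing import Any, Dict

    def beam_search(source: Dict[str, Any], reset: Any, k: int, **kwargs) -> None:
        # **kwargs absorbs batch fields this model's search does not use,
        # e.g. 'target' and 'metadata', so beam_search(**batch, k=k) never
        # raises a TypeError for unexpected keyword arguments.
        print(f"k={k}; ignored fields: {sorted(kwargs)}")

    batch = {"source": {}, "reset": None, "target": {}, "metadata": [{}]}
    beam_search(**batch, k=5)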
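
The rewritten `split` in evaluate_perplexity.py materializes its chunks into a list rather than yielding them lazily, while still zeroing `reset` on every chunk after the first so that model state carries across chunk boundaries within a batch. A toy re-implementation under the assumption that all fields are [batch, time] tensors (the real `_chunk` also recurses into nested dicts):

    import torch

    def split(batch: dict, split_size: int):
        # Chunk every [batch, time] tensor along the time dimension.
        length = batch['tokens'].size(1)
        num_splits = (length + split_size - 1) // split_size
        chunks = []
        for i in range(num_splits):
            chunk = {k: v[:, i * split_size:(i + 1) * split_size].contiguous()
                     for k, v in batch.items()}
            if i > 0:
                # Only the first chunk may reset model state.
                chunk['reset'] = torch.zeros_like(chunk['reset'])
            chunks.append(chunk)
        return chunks  # a list, so callers can len() it or iterate twice

    batch = {'tokens': torch.arange(20).view(2, 10),
             'reset': torch.ones(2, 10, dtype=torch.uint8)}
    for c in split(batch, split_size=4):
        print(c['tokens'].shape, int(c['reset'].sum()))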
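
For reference, the `logsumexp` helper commented out in evaluate_perplexity.py maintained a running log-mean of importance weights without leaving log space: with a = max(prev, current), the update is a + log(i/(i+1) * exp(prev - a) + 1/(i+1) * exp(current - a)), the standard max-shift trick for numerical stability. A standalone sketch of that recurrence, checked against a one-shot `logsumexp` (names are mine, not the repo's):

    import torch

    def running_log_mean_exp(prev, current, i):
        # prev: log-mean of the first i values (None before the first).
        # current: log of the (i + 1)-th value.
        # Returns the log-mean over all i + 1 values.
        if prev is None:
            return current
        a = torch.max(prev, current)  # max-shift for numerical stability
        sumexp = torch.exp(prev - a) * i / (i + 1) + torch.exp(current - a) / (i + 1)
        return a + torch.log(sumexp)

    values = torch.randn(5)
    running = None
    for i, v in enumerate(values):
        running = running_log_mean_exp(running, v, i)

    # One-shot equivalent: logsumexp over all values minus log(n).
    direct = values.logsumexp(dim=0) - torch.log(torch.tensor(5.0))
    assert torch.allclose(running, direct)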
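
The `torch.LongTensor(..., device=...)` calls replaced in kglm_disc.py are a known failure mode: the legacy type constructors expect CPU storage and error out when handed a CUDA device, whereas the `torch.tensor` factory takes explicit `dtype` and `device` arguments. A minimal sketch of the replacement pattern (the device selection here is illustrative; the patch always uses `next_encoded_head.device`):

    import torch

    # Illustrative device choice, not from the patch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    candidate_keys = [3, 17, 42]

    # torch.LongTensor(candidate_keys, device=device) fails on CUDA devices;
    # the factory function below is the drop-in replacement used in the patch.
    candidate_ids = torch.tensor(candidate_keys, dtype=torch.int64, device=device)

    # Empty placeholders for isolated nodes keep dtype and device explicit too.
    empty_logp = torch.tensor([], dtype=torch.float32, device=device)
    empty_ids = torch.tensor([], dtype=torch.int64, device=device)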