Commit

ACL 2020 paper updates
rloganiv committed Jul 13, 2020
1 parent 453189b commit ef6b352
Showing 7 changed files with 46 additions and 33 deletions.
Dockerfile (5 changes: 3 additions & 2 deletions)
@@ -1,5 +1,5 @@
-FROM python:3.6.8-jessie
-
+# FROM python:3.6.8-jessie
+FROM pytorch/pytorch:1.1.0-cuda10.0-cudnn7.5-runtime
 ENV LC_ALL=C.UTF-8
 ENV LANG=C.UTF-8

@@ -18,6 +18,7 @@ WORKDIR /workspace
 RUN chmod -R a+w /workspace

 COPY requirements.txt .
+RUN pip install --upgrade pip
 RUN pip install -r requirements.txt

 COPY .pylintrc .pylintrc
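Note: the base-image swap moves the build from plain `python:3.6.8-jessie` onto PyTorch's official `pytorch/pytorch:1.1.0-cuda10.0-cudnn7.5-runtime` image, which ships with CUDA 10.0 and cuDNN 7.5, so GPU-enabled PyTorch no longer has to be installed on top of a bare Python base. Upgrading pip before `pip install -r requirements.txt` is a common guard against install failures from an outdated pip in the base image.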
kglm/commands/beamsum.py (9 changes: 6 additions & 3 deletions)
@@ -112,9 +112,12 @@ def evaluate_perplexity(model: Model,

         # Draw a sample
         with torch.no_grad():
-            sample = sampler.beam_search(source=batch['source'],
-                                         reset=batch['reset'],
-                                         k=beam_width)
+            # sample = sampler.beam_search(source=batch['source'],
+            #                              reset=batch['reset'],
+            #                              target=batch['target'],
+            #                              metadata=batch['metadata'],
+            #                              k=beam_width)
+            sample = sampler.beam_search(**batch, k=beam_width)

         # Evaluate on sample
         with torch.no_grad():
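Note: the old call passed `source` and `reset` explicitly; the new one unpacks the whole batch dict. This relies on the `**kwargs` parameter added to the `beam_search` signatures in `entity_disc.py` and `kglm_disc.py` below, which absorbs batch keys a given sampler does not use. A minimal sketch of the pattern (the `batch` contents here are illustrative, not the real model inputs):

def beam_search(source, reset, k, **kwargs):
    # Extra batch keys such as 'target' and 'metadata' land in kwargs
    # and are simply ignored by samplers that do not need them.
    return f"searching {len(source)} sources with k={k}"

batch = {'source': [1, 2, 3], 'reset': [True], 'target': [4, 5], 'metadata': [{}]}
print(beam_search(**batch, k=5))  # no TypeError despite the extra keys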
kglm/commands/evaluate_perplexity.py (42 changes: 22 additions & 20 deletions)
@@ -118,13 +118,15 @@ def _chunk(x, start, stop):
         if isinstance(x, torch.Tensor):
             return x[:, start:stop].contiguous()

+    chunks = []
     for i in range(num_splits):
         chunk = _chunk(batch, i * split_size, (i + 1) * split_size)

         if i > 0:
             chunk['reset'] = torch.zeros_like(chunk['reset'])

-        yield chunk
+        chunks.append(chunk)
+    return chunks


 def tile(t, amount):
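Note: with `yield`, `split` was a generator, which has no length and is consumed after one pass; building and returning a list trades a little memory for a reusable, sized sequence of chunks. A toy illustration of the difference (function and variable names here are illustrative, not repo code):

def split_gen(xs, n):
    for i in range(0, len(xs), n):
        yield xs[i:i + n]  # consumed once; a second loop sees nothing

def split_list(xs, n):
    chunks = []
    for i in range(0, len(xs), n):
        chunks.append(xs[i:i + n])
    return chunks  # a real list: len(), indexing, and re-iteration all work

gen = split_gen(list(range(6)), 2)
assert list(gen) == [[0, 1], [2, 3], [4, 5]]
assert list(gen) == []  # the generator is exhausted
assert len(split_list(list(range(6)), 2)) == 3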
@@ -138,17 +140,17 @@ def tile(t, amount):
     return [x for x in t for _ in range(amount)]


-def logsumexp(prev: torch.FloatTensor,
-              current: torch.FloatTensor,
-              i: int,
-              samples_per_batch: int):
-    # NOTE: n is number of samples
-    current_avg = current.view(samples_per_batch, -1).sum(dim=-1).logsumexp(dim=0) - np.log(samples_per_batch).item()
-    if prev is None:
-        return current_avg
-    a = torch.max(prev, current_avg)
-    sumexp = torch.exp(prev - a) * i / (i + 1) + torch.exp(current_avg - a) / (i + 1)
-    return a + torch.log(sumexp)
+# def logsumexp(prev: torch.FloatTensor,
+#               current: torch.FloatTensor,
+#               i: int,
+#               samples_per_batch: int):
+#     # NOTE: n is number of samples
+#     current_avg = current.view(samples_per_batch, -1).sum(dim=-1).logsumexp(dim=0) - np.log(samples_per_batch).item()
+#     if prev is None:
+#         return current_avg
+#     a = torch.max(prev, current_avg)
+#     sumexp = torch.exp(prev - a) * i / (i + 1) + torch.exp(current_avg - a) / (i + 1)
+#     return a + torch.log(sumexp)


 def evaluate_perplexity(model: Model,
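Note: the function retired here maintained a numerically stable running mean of importance weights in log space: after chunk i it held log((i * e^prev + e^current) / (i + 1)). Since the loop below now keeps every per-epoch weight and concatenates them (`torch.cat(epoch_weights, dim=1)`), the same average can presumably be taken in one shot at the end, making the streaming update unnecessary. A plain-Python check that the two formulations agree (toy values, not repo code):

import math

def running_log_mean(prev, current, i):
    # Stable log of the running mean: log((i * e^prev + e^current) / (i + 1))
    if prev is None:
        return current
    a = max(prev, current)
    return a + math.log(math.exp(prev - a) * i / (i + 1)
                        + math.exp(current - a) / (i + 1))

vals = [-1000.0, -1001.0, -999.5]  # log weights far too small for a naive exp()
stream = None
for i, v in enumerate(vals):
    stream = running_log_mean(stream, v, i)

# One-shot log-mean-exp over the same values
a = max(vals)
batch = a + math.log(sum(math.exp(v - a) for v in vals) / len(vals))
assert abs(stream - batch) < 1e-6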
@@ -204,7 +206,7 @@ def evaluate_perplexity(model: Model,
         for j, chunk in enumerate(split(batch, split_size)):
             generator_tqdm.set_description(f"i={i} j={j}")

-            chunk_tokens = util.get_text_field_mask(batch['source']).int().sum().item()
+            chunk_tokens = util.get_text_field_mask(batch['source']).int().sum()
             if chunk_tokens == 0:
                 logger.debug('Zero chunk, skipping')
                 continue
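Note: dropping `.item()` keeps `chunk_tokens` as a 0-dim tensor rather than forcing an immediate device-to-host copy, consistent with this commit's broader pattern of leaving values on the GPU. The `== 0` guard still behaves the same, since 0-dim tensors support comparison and truth-testing; a quick sketch:

import torch

mask = torch.zeros(2, 3, dtype=torch.int32)  # stand-in for a token mask
chunk_tokens = mask.sum()                    # 0-dim tensor, not a Python int

# A 0-dim tensor can be compared and truth-tested directly,
# so the guard reads the same as it did with .item().
if chunk_tokens == 0:
    print('zero chunk, skipping')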
@@ -230,11 +232,11 @@ def evaluate_perplexity(model: Model,
                 weights = split_weights
             else:
                 weights += split_weights
-            logger.debug(torch.exp(-split_weights/split_size))
+            # logger.debug(torch.exp(-split_weights/split_size))

-        epoch_weights.append(weights.cpu())
-        epoch_fp.append(model_logp.view(batch_size, samples_per_batch).cpu())
-        epoch_q.append(sample_logp.view(batch_size, samples_per_batch).cpu())
+        epoch_weights.append(weights) #.cpu())
+        epoch_fp.append(model_logp.view(batch_size, samples_per_batch))# .cpu())
+        epoch_q.append(sample_logp.view(batch_size, samples_per_batch))# .cpu())

         # Combine all the epoch weights
         combined_weights = torch.cat(epoch_weights, dim=1)
@@ -251,9 +253,9 @@ def evaluate_perplexity(model: Model,
     logger.info(f'PPL: {torch.exp(-summand / denom)}')

     # Create array of all the weights
-    all_weights_array = torch.cat(all_weights, dim=0).numpy()
-    fp_array = torch.cat(fp, dim=0).numpy()
-    q_array = torch.cat(q, dim=0).numpy()
+    all_weights_array = torch.cat(all_weights, dim=0).cpu().numpy()
+    fp_array = torch.cat(fp, dim=0).cpu().numpy()
+    q_array = torch.cat(q, dim=0).cpu().numpy()

     # Compute perplexity
     ppl = torch.exp(-summand / denom)
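Note: the last two hunks are one change. Per-chunk tensors are no longer copied to the CPU inside the loop; they accumulate on the GPU, and a single `.cpu()` is applied to the concatenated result just before `.numpy()`, which requires a CPU tensor. A schematic of the pattern (shapes and names illustrative):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Before: one device-to-host copy per iteration.
# After: accumulate on-device, one copy at the end.
weights = [torch.randn(4, 2, device=device) for _ in range(10)]
all_weights = torch.cat(weights, dim=1)        # still on `device`
all_weights_array = all_weights.cpu().numpy()  # single transfer; .numpy() needs CPU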
kglm/data/__init__.py (1 change: 1 addition & 0 deletions)
@@ -2,3 +2,4 @@
 from .fields import SequentialArrayField
 from .iterators import SplitIterator
 from .extended_vocabulary import ExtendedVocabulary
+from .dataset_readers import *
kglm/models/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -0,0 +1,2 @@
+from .entity_disc import EntityNLMDiscriminator
+from .entity_nlm import EntityNLM
kglm/models/entity_disc.py (3 changes: 2 additions & 1 deletion)
@@ -556,7 +556,8 @@ def _trace_backpointers(source: Dict[str, torch.Tensor],
     def beam_search(self,
                     source: Dict[str, torch.Tensor],
                     reset: torch.ByteTensor,
-                    k: int) -> Tuple[torch.Tensor, torch.Tensor]:
+                    k: int,
+                    **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Obtain the top-k (approximately) most likely predictions from the model using beam
         search. Unlike typical beam search all of the beam states are returned instead of just
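Note: this is the signature half of the `beamsum.py` change above. With `**kwargs` in place, `sampler.beam_search(**batch, k=beam_width)` can forward batch keys such as `target` and `metadata` that `EntityNLMDiscriminator` does not use, without raising a `TypeError`.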
kglm/models/kglm_disc.py (17 changes: 10 additions & 7 deletions)
@@ -755,8 +755,9 @@ def _next_related_entity_logp(self, next_encoded_head, next_encoded_relation, be
         self._recent_entities.load_beam_state(beam_state.recent_entities)
         for i, candidate_ids in enumerate(self._recent_entities._remaining):
             # Cast candidate ids to a tensor, lookup embeddings, and compute score.
-            candidate_ids = torch.LongTensor(list(candidate_ids.keys()),
-                                             device=next_encoded_head.device)
+            candidate_ids = torch.tensor(list(candidate_ids.keys()),
+                                         dtype=torch.int64,
+                                         device=next_encoded_head.device)
             candidate_embeddings = self._entity_embedder(candidate_ids)
             candidate_logits = torch.mv(candidate_embeddings, next_encoded_head[i])
             candidate_logp = F.log_softmax(candidate_logits)
@@ -767,10 +768,10 @@ def _next_related_entity_logp(self, next_encoded_head, next_encoded_relation, be

             # Stop early if node is isolated
             if not s:
-                logp_arr[i, j] = torch.FloatTensor([], device=next_encoded_head.device)
-                parent_ids_arr[i, j] = torch.LongTensor([], device=next_encoded_head.device)
-                relations_arr[i, j] = torch.LongTensor([], device=next_encoded_head.device)
-                raw_entity_ids_arr[i, j] = torch.LongTensor([], device=next_encoded_head.device)
+                logp_arr[i, j] = torch.tensor([], dtype=torch.float32, device=next_encoded_head.device)
+                parent_ids_arr[i, j] = torch.tensor([], dtype=torch.int64, device=next_encoded_head.device)
+                relations_arr[i, j] = torch.tensor([], dtype=torch.int64, device=next_encoded_head.device)
+                raw_entity_ids_arr[i, j] = torch.tensor([], dtype=torch.int64, device=next_encoded_head.device)
                 continue

             # Otherwise compute relation probabilities for each parent and combine
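Note: context for the `torch.tensor` swap in both hunks. The legacy constructors (`torch.LongTensor`, `torch.FloatTensor`) do not support passing data together with a `device` keyword; on a CUDA device the call fails outright. The `torch.tensor` factory takes data, `dtype`, and `device` in one call. A minimal reproduction (exact error type and message vary by PyTorch version):

import torch

keys = [3, 7, 11]

# Legacy constructor: data plus a device keyword is rejected.
try:
    bad = torch.LongTensor(keys, device='cuda')
except (TypeError, RuntimeError) as err:
    print(f'legacy constructor rejected: {err}')

# Supported factory: data, dtype, and device together.
good = torch.tensor(keys, dtype=torch.int64, device='cpu')
assert good.dtype == torch.int64 and good.tolist() == keys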
@@ -1020,7 +1021,8 @@ def beam_search(self,
                     target: Dict[str, torch.Tensor],
                     reset: torch.ByteTensor,
                     metadata: Dict[str, Any],
-                    k: int) -> Tuple[torch.Tensor, torch.Tensor]:
+                    k: int,
+                    **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Obtain the top-k (approximately) most likely predictions from the model using beam
         search. Unlike typical beam search all of the beam states are returned instead of just
@@ -1084,6 +1086,7 @@ def beam_search(self,
         output = None

         for timestep in range(sequence_length):
+            logger.debug(timestep)
             # Get log probabilities of all next states
             next_mention_type_logp = self._next_mention_type_logp(mention_type_logits[:, timestep],
                                                                   beam_states)
