Cleaner importance sampling logging + added offset option to KglmDisc
rloganiv committed Dec 9, 2019
1 parent 39303d5 commit 453189b
Showing 2 changed files with 19 additions and 8 deletions.
kglm/commands/evaluate_perplexity.py (21 changes: 14 additions & 7 deletions)
@@ -58,7 +58,7 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
 
         subparser.add_argument('--split-size',
                                type=int,
-                               default=1e10,
+                               default=int(1e10),
                                help='Split size (default: whatever iterator was set to)')
 
         subparser.add_argument('--num-samples',
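
Replacing the literal 1e10 with int(1e10) is more than cosmetic: 1e10 is a Python float, and a float split_size breaks integer-only uses like range() and tensor slicing, while silently turning sequence_length // split_size into a float. A quick interpreter check (not from the repository) shows the difference:

>>> type(1e10)
<class 'float'>
>>> range(1e10)
Traceback (most recent call last):
  ...
TypeError: 'float' object cannot be interpreted as an integer
>>> range(int(1e10))
range(0, 10000000000)

The same fix recurs below in the evaluate_perplexity default and in iterator_params['split_size'].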
@@ -107,6 +107,10 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
 def split(batch, split_size: int):
     sequence_length = batch['source']['tokens'].shape[1]
     num_splits = sequence_length // split_size
+    if not ((sequence_length % split_size) == 0):
+        num_splits += 1
+    else:
+        logger.warning('Perfect split')
 
     def _chunk(x, start, stop):
         if isinstance(x, dict):
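
The added branch is ceiling division: floor-divide, then add one chunk whenever split_size leaves a remainder. A standalone sketch of the same arithmetic (the helper name is illustrative, not from the repository):

import math

def num_chunks(sequence_length: int, split_size: int) -> int:
    # One-line equivalent of the branch above, for positive integers.
    return (sequence_length + split_size - 1) // split_size

assert num_chunks(10, 5) == 2 == math.ceil(10 / 5)  # perfect split
assert num_chunks(11, 5) == 3 == math.ceil(11 / 5)  # remainder adds a chunk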
@@ -156,7 +160,7 @@ def evaluate_perplexity(model: Model,
                          temperature: float = 1.0,
                          offset: bool = False,
                          samples_per_batch: int = 1,
-                         split_size: int = 1e10) -> Dict[str, Any]:
+                         split_size: int = int(1e10)) -> Dict[str, Any]:
 
     check_for_gpu(cuda_device)
     logger.info('Iterating over dataset')
@@ -193,14 +197,17 @@
 
     for i in range(num_samples // samples_per_batch):
 
-        logger.info(f'i={i}')
-
         # summand = util.move_to_device(summand, cuda_device)
         # batch = util.move_to_device(batch, cuda_device)
 
         weights = None
-        for chunk in split(batch, split_size):
-            logger.info('next_chunk')
+        for j, chunk in enumerate(split(batch, split_size)):
+            generator_tqdm.set_description(f"i={i} j={j}")
+
+            chunk_tokens = util.get_text_field_mask(batch['source']).int().sum().item()
+            if chunk_tokens == 0:
+                logger.debug('Zero chunk, skipping')
+                continue
 
             # Draw a sample
             with torch.no_grad():
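
This hunk is the "cleaner logging" from the commit title: instead of emitting a log line per sample (i) and per chunk, progress is written onto the tqdm bar via set_description, and chunks with no real tokens are skipped. A minimal sketch of the pattern with stand-in data and a stand-in split (tqdm and set_description are the real API; everything else is illustrative):

from tqdm import tqdm

def split(batch, split_size):
    # Stand-in chunker: fixed-size slices of a token list.
    for start in range(0, len(batch), split_size):
        yield batch[start:start + split_size]

batches = [['the', 'cat', 'sat'], ['<pad>', '<pad>']]
generator_tqdm = tqdm(batches)
for i, batch in enumerate(generator_tqdm):
    for j, chunk in enumerate(split(batch, 2)):
        generator_tqdm.set_description(f'i={i} j={j}')  # progress lives on the bar
        if all(token == '<pad>' for token in chunk):
            continue  # analogous to skipping chunks with zero real tokens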
@@ -294,7 +301,7 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
     # deal with averaging the importance samples across splits ...
     # if args.split_size is not None:
     #     iterator_params['split_size'] = args.split_size
-    iterator_params['split_size'] = 1e10
+    iterator_params['split_size'] = int(1e10)
     iterator_params['truncate'] = False  # TODO: Shouldn't need this anymore...
     iterator = DataIterator.from_params(iterator_params)
     iterator.index_with(model.vocab)
kglm/models/kglm_disc.py (6 changes: 5 additions & 1 deletion)
@@ -175,6 +175,7 @@ def sample(self,
                alias_copy_inds: torch.Tensor,
                shortlist: Dict[str, torch.Tensor] = None,
                temperature: float = 1.0,
+               offset: bool = False,
                **kwargs) -> Dict[str, Any]:  # **kwargs intended to eat the other fields if they are provided.
        """
        Sampling annotations for the generative model. Note that unlike forward, this function
@@ -193,7 +194,10 @@
         # We encode the target tokens (**not** source) since the discriminative model makes
         # predictions on the current token, but the generative model expects labels for the
         # **next** (e.g. target) token!
-        encoded, *_ = self._encode_source(target['tokens'])
+        if not offset:
+            encoded, *_ = self._encode_source(target['tokens'])
+        else:
+            encoded, *_ = self._encode_source(source['tokens'])
         splits = [self.token_embedding_dim] + [self.entity_embedding_dim] * 2
         encoded_token, encoded_head, encoded_relation = encoded.split(splits, dim=-1)
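
For context on what the new offset flag toggles: the target sequence is the source shifted by one position, so encoding target aligns each hidden state with the next token's label (what the generative model expects), while offset=True encodes source and keeps the current-token alignment. A toy illustration of that shift (the tokenization here is illustrative, not the repository's):

tokens = ['<s>', 'the', 'cat', 'sat', '</s>']
source = tokens[:-1]  # model input at each step
target = tokens[1:]   # label at each step: the *next* token

for step_input, label in zip(source, target):
    print(f'input={step_input!r:>8}  label={label!r}')
# offset=False (default): encode target, matching next-token labels.
# offset=True: encode source, keeping the current-token alignment.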
