diff --git a/kglm/commands/evaluate_perplexity.py b/kglm/commands/evaluate_perplexity.py
index e71caa6..6d0fe25 100644
--- a/kglm/commands/evaluate_perplexity.py
+++ b/kglm/commands/evaluate_perplexity.py
@@ -58,7 +58,7 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argpar
         subparser.add_argument('--split-size',
                                type=int,
-                               default=1e10,
+                               default=int(1e10),
                                help='Split size (default: whatever iterator was set to)')
         subparser.add_argument('--num-samples',
@@ -107,6 +107,10 @@ def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argpar
 def split(batch, split_size: int):
     sequence_length = batch['source']['tokens'].shape[1]
     num_splits = sequence_length // split_size
+    if not ((sequence_length % split_size) == 0):
+        num_splits += 1
+    else:
+        logger.warning('Perfect split')

     def _chunk(x, start, stop):
         if isinstance(x, dict):
@@ -156,7 +160,7 @@ def evaluate_perplexity(model: Model,
                         temperature: float = 1.0,
                         offset: bool = False,
                         samples_per_batch: int = 1,
-                        split_size: int = 1e10) -> Dict[str, Any]:
+                        split_size: int = int(1e10)) -> Dict[str, Any]:
     check_for_gpu(cuda_device)

     logger.info('Iterating over dataset')
@@ -193,14 +197,17 @@ def evaluate_perplexity(model: Model,

             for i in range(num_samples // samples_per_batch):
-                logger.info(f'i={i}')
-
                 # summand = util.move_to_device(summand, cuda_device)
                 # batch = util.move_to_device(batch, cuda_device)
                 weights = None

-                for chunk in split(batch, split_size):
-                    logger.info('next_chunk')
+                for j, chunk in enumerate(split(batch, split_size)):
+                    generator_tqdm.set_description(f"i={i} j={j}")
+
+                    chunk_tokens = util.get_text_field_mask(batch['source']).int().sum().item()
+                    if chunk_tokens == 0:
+                        logger.debug('Zero chunk, skipping')
+                        continue

                     # Draw a sample
                     with torch.no_grad():
@@ -294,7 +301,7 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
     # deal with averaging the importance samples across splits ...
     # if args.split_size is not None:
     #     iterator_params['split_size'] = args.split_size
-    iterator_params['split_size'] = 1e10
+    iterator_params['split_size'] = int(1e10)
     iterator_params['truncate'] = False  # TODO: Shouldn't need this anymore...
     iterator = DataIterator.from_params(iterator_params)
     iterator.index_with(model.vocab)
diff --git a/kglm/models/kglm_disc.py b/kglm/models/kglm_disc.py
index aa24148..ef81587 100644
--- a/kglm/models/kglm_disc.py
+++ b/kglm/models/kglm_disc.py
@@ -175,6 +175,7 @@ def sample(self,
                alias_copy_inds: torch.Tensor,
                shortlist: Dict[str, torch.Tensor] = None,
                temperature: float = 1.0,
+               offset: bool = False,
                **kwargs) -> Dict[str, Any]:  # **kwargs intended to eat the other fields if they are provided.
         """
         Sampling annotations for the generative model. Note that unlike forward, this function
@@ -193,7 +194,10 @@ def sample(self,
         # We encode the target tokens (**not** source) since the discriminitative model makes
         # predictions on the current token, but the generative model expects labels for the
         # **next** (e.g. target) token!
-        encoded, *_ = self._encode_source(target['tokens'])
+        if not offset:
+            encoded, *_ = self._encode_source(target['tokens'])
+        else:
+            encoded, *_ = self._encode_source(source['tokens'])
         splits = [self.token_embedding_dim] + [self.entity_embedding_dim] * 2
         encoded_token, encoded_head, encoded_relation = encoded.split(splits, dim=-1)
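For reference, the change to split above amounts to ceil-division chunking: when the sequence length is not an exact multiple of split_size, the final, shorter chunk is now counted rather than dropped. A minimal standalone sketch of that pattern (hypothetical names, not the repository's code):

    import math

    def chunk_bounds(sequence_length: int, split_size: int):
        # Yield (start, stop) pairs covering the whole sequence; the last
        # pair may span fewer than split_size positions.
        num_splits = math.ceil(sequence_length / split_size)  # floor division plus a remainder bump
        for i in range(num_splits):
            start = i * split_size
            stop = min(start + split_size, sequence_length)
            yield start, stop

    # Example: a length-70 sequence with split_size=32 -> [(0, 32), (32, 64), (64, 70)]
    print(list(chunk_bounds(70, 32)))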