diff --git a/src/lighteval/data.py b/src/lighteval/data.py index 247cff042..22b68bd6a 100644 --- a/src/lighteval/data.py +++ b/src/lighteval/data.py @@ -40,12 +40,12 @@ class DynamicBatchDataset(Dataset): def __init__( self, requests: list, - dataset_splits: int, + num_dataset_splits: int, ): """ This dataset class uses dynamic batching to speed up the generation. Each request is sorted by the length of the prompt + the length of the - continuation. Then, the dataset is split into dataset_splits splits. + continuation. Then, the dataset is split into num_dataset_splits splits. The first split will contain the longest requests, the second split will contain the second longest requests, etc. This allows us to use dynamic batching by starting with a small batch size and doubling it for each @@ -54,7 +54,7 @@ def __init__( Args: requests (List): A list of requests. - dataset_splits (int): The number of dataset splits. + num_dataset_splits (int): The number of dataset splits. """ # We make sure the requests contain the tokenized versions of their values if any(r.tokenized_context is None for r in requests): @@ -69,16 +69,23 @@ def __init__( self.total_size = len(self.sorted_data) - if dataset_splits >= self.total_size: + self.num_dataset_splits, self.splits = self.init_split_limits(num_dataset_splits) + + self.split_start, self.split_end = self.splits[0] + + def init_split_limits(self, num_dataset_splits): + if num_dataset_splits >= self.total_size: hlog_warn( - f"dataset_splits ({dataset_splits}) >= total_size ({self.total_size}), setting dataset_splits to 1" + f"num_dataset_splits ({num_dataset_splits}) >= total_size ({self.total_size}), setting num_dataset_splits to 1" ) - dataset_splits = 1 + num_dataset_splits = 1 - self.dataset_splits = dataset_splits - self.split_size = self.total_size // self.dataset_splits + 1 - self.split_start = 0 - self.split_end = min(self.split_start + self.split_size, self.total_size) + split_size = self.total_size // num_dataset_splits + 1 + splits_indices = [ + (ix * split_size, min((ix + 1) * split_size, self.total_size)) for ix in range(num_dataset_splits) + ] + + return num_dataset_splits, splits_indices def get_original_order(self, new_arr: list) -> list: """ @@ -113,8 +120,7 @@ def get_split_start_end(self, split_id: int) -> tuple[int, int]: Returns: tuple: A tuple containing the start and end indices of the split. """ - self.split_start = split_id * self.split_size - self.split_end = min(self.split_start + self.split_size, self.total_size) + self.split_start, self.split_end = self.splits[split_id] return self.split_start, self.split_end def splits_start_end_iterator(self) -> tuple[int, int]: @@ -126,7 +132,7 @@ def splits_start_end_iterator(self) -> tuple[int, int]: Yields: tuple: A tuple containing the start and end indices of a split. """ - for split_id in range(self.dataset_splits): + for split_id in range(self.num_dataset_splits): yield self.get_split_start_end(split_id) def __getitem__(self, index) -> Request: @@ -204,7 +210,47 @@ def _sorting_criteria(self, request: LoglikelihoodSingleTokenRequest) -> int: class GenerativeTaskDataset(DynamicBatchDataset): - def _sorting_criteria(self, request: GreedyUntilRequest) -> int: + def init_split_limits(self, num_dataset_splits): + """Initialises the split limits based on generation parameters. + The splits are used to estimate time remaining when evaluating, and in the case of generative evaluations, to group similar samples together. 
+
+        For generative tasks, self._sorting_criteria outputs:
+        - a boolean (whether the generation task uses logits)
+        - a list (the stop sequences)
+        - the item length (the actual size sorting factor).
+
+        In the current function, we create evaluation groups by generation parameters (whether logits are returned and the stop sequences used), so that samples with similar properties get batched together afterwards.
+        The samples will then be further organised by length in each split.
+
+        Args:
+            num_dataset_splits (int): The requested number of splits. Ignored here, since the number of splits is inferred from the generation parameters.
+
+        Returns:
+            tuple[int, list[tuple[int, int]]]: The inferred number of splits and the (start, end) indices of each split.
+        """
+        if num_dataset_splits is not None:
+            hlog_warn(
+                "You cannot select the number of dataset splits for a generative evaluation at the moment. Automatically inferring."
+            )
+
+        all_sorting_criterion = [self._sorting_criteria(self.sorted_data[0])[:2]]
+        splits_indices = [[0, None]]
+        for ix, req in enumerate(self.sorted_data):
+            current_sorting_criteria = self._sorting_criteria(req)
+            current_key = current_sorting_criteria[:2]
+            if current_key not in all_sorting_criterion:
+                all_sorting_criterion.append(current_key)
+                splits_indices[-1][1] = ix
+                splits_indices.append([ix, None])
+
+        # We add the last split
+        splits_indices[-1][1] = self.total_size
+
+        num_dataset_splits = len(splits_indices)
+        splits_indices = [tuple(e) for e in splits_indices]
+        return num_dataset_splits, splits_indices
+
+    def _sorting_criteria(self, request: GreedyUntilRequest) -> tuple[bool, list, int]:
         """
         Collate function for generating batches.
 
@@ -219,10 +265,10 @@ def _sorting_criteria(self, request: GreedyUntilRequest) -> int:
         # The generative task has no limit except the model context
         if gen_length is None:
             gen_length = 0
-        return -(len(toks) + gen_length)
+        return request.use_logits, request.stop_sequence, -(len(toks) + gen_length)
 
 
-class GenerativeTaskDatasetNanotron(DynamicBatchDataset):
+class GenerativeTaskDatasetNanotron(GenerativeTaskDataset):
     def __getitem__(self, index) -> Request:
         """
         Get an item from the dataset depending on the split we are currently in.
@@ -238,20 +284,6 @@ def __getitem__(self, index) -> Request:
         """
         return index, self.sorted_data[index + self.split_start]
 
-    def _sorting_criteria(self, request) -> int:
-        """
-        Collate function for generating batches.
-
-        Args:
-            x (Any): The input data.
-
-        Returns:
-            Any: The collated data.
-        """
-        toks = request.tokenized_context
-        gen_length = request.generation_size
-        return -(len(toks) + gen_length)
-
 
 class GenDistributedSampler(DistributedSampler):
     """A distributed sampler that copy the last element only when drop_last is False so we keep a small padding in the batches
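The grouping performed by the new init_split_limits is easiest to see on a toy input. Below is a minimal sketch (not part of the diff) that mirrors the loop above outside the class; FakeRequest and group_split_limits are hypothetical names, and FakeRequest carries only the two fields the grouping key uses, use_logits and stop_sequence.

from dataclasses import dataclass, field


@dataclass
class FakeRequest:
    # Hypothetical stand-in for GreedyUntilRequest: only the grouping fields.
    use_logits: bool
    stop_sequence: list = field(default_factory=list)


def group_split_limits(sorted_data):
    # Mirrors init_split_limits: open a new split whenever the
    # (use_logits, stop_sequence) key changes along the pre-sorted data.
    all_sorting_criterion = [(sorted_data[0].use_logits, sorted_data[0].stop_sequence)]
    splits_indices = [[0, None]]
    for ix, req in enumerate(sorted_data):
        current_key = (req.use_logits, req.stop_sequence)
        if current_key not in all_sorting_criterion:
            all_sorting_criterion.append(current_key)
            splits_indices[-1][1] = ix
            splits_indices.append([ix, None])
    splits_indices[-1][1] = len(sorted_data)
    return len(splits_indices), [tuple(e) for e in splits_indices]


requests = [FakeRequest(False, ["\n"])] * 3 + [FakeRequest(True, ["\n"])] * 2
print(group_split_limits(requests))  # -> (2, [(0, 3), (3, 5)])

With three non-logit requests followed by two logit requests, the dataset is cut into one split per generation-parameter group, so similar samples end up batched together.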
- """ - toks = request.tokenized_context - gen_length = request.generation_size - return -(len(toks) + gen_length) - class GenDistributedSampler(DistributedSampler): """A distributed sampler that copy the last element only when drop_last is False so we keep a small padding in the batches diff --git a/src/lighteval/metrics/__init__.py b/src/lighteval/metrics/__init__.py index 3d5257562..b7f9183a8 100644 --- a/src/lighteval/metrics/__init__.py +++ b/src/lighteval/metrics/__init__.py @@ -100,7 +100,7 @@ def apply_generative_metric( outputs.update( Metrics[metric].value.compute( golds=golds, - predictions=as_list(preds[0]) if max_num_samples > 0 else preds, + predictions=as_list(preds[0]) if max_num_samples > 1 else preds, formatted_doc=formatted_doc, ) ) @@ -108,7 +108,7 @@ def apply_generative_metric( outputs.update( Metrics[metric].value.compute( golds=golds, - predictions=as_list(preds[0]) if max_num_samples > 0 else preds, + predictions=as_list(preds[0]) if max_num_samples > 1 else preds, formatted_doc=formatted_doc, ) ) diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py index 7f17f24d9..3e483d448 100644 --- a/src/lighteval/models/base_model.py +++ b/src/lighteval/models/base_model.py @@ -335,7 +335,7 @@ def greedy_until_multi_turn( # noqa: C901 results = [] - dataset = GenerativeTaskDataset(requests=requests, dataset_splits=1) + dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=1) dataloader = DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch) if self.accelerator: @@ -480,13 +480,13 @@ def greedy_until( request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token] request.tokenized_context = self.tok_encode(request.context) - dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS) + dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS) starting_batch_size = STARTING_BATCH_SIZE results = [] for split_start, split_end in tqdm( dataset.splits_start_end_iterator(), - total=self.DATASET_SPLITS, + total=dataset.num_dataset_splits, desc="Splits", position=0, disable=self.disable_tqdm, @@ -715,7 +715,7 @@ def _loglikelihood_tokens( return_bool_score: bool = True, rolling: bool = False, ) -> list[LoglikelihoodReturn]: - dataset = LoglikelihoodDataset(requests=requests, dataset_splits=self.DATASET_SPLITS) + dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS) starting_batch_size = STARTING_BATCH_SIZE res = [] @@ -957,7 +957,7 @@ def loglikelihood_single_token( def _loglikelihood_single_token( self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: int = -1 ) -> list[LoglikelihoodSingleTokenReturn]: - dataset = LoglikelihoodSingleTokenDataset(requests=requests, dataset_splits=self.DATASET_SPLITS) + dataset = LoglikelihoodSingleTokenDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS) starting_batch_size = STARTING_BATCH_SIZE res = [] diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoint_model.py index b7e9af31b..83f42916c 100644 --- a/src/lighteval/models/endpoint_model.py +++ b/src/lighteval/models/endpoint_model.py @@ -243,7 +243,7 @@ def greedy_until( request.tokenized_context = self.tok_encode(request.context) request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token] - dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS) + dataset = GenerativeTaskDataset(requests=requests, 
diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoint_model.py
index b7e9af31b..83f42916c 100644
--- a/src/lighteval/models/endpoint_model.py
+++ b/src/lighteval/models/endpoint_model.py
@@ -243,7 +243,7 @@ def greedy_until(
             request.tokenized_context = self.tok_encode(request.context)
             request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]
 
-        dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
+        dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         batch_size = override_bs if override_bs is not None else BATCH_SIZE
 
         results: List[str] = []
@@ -289,7 +289,7 @@ def loglikelihood(
         for request in requests:
             request.tokenized_context = self.tok_encode(request.context)
             request.tokenized_continuation = self.tok_encode(request.choice)
-        dataset = LoglikelihoodDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
+        dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         batch_size = override_bs if override_bs is not None else BATCH_SIZE
 
         results: List[str] = []
@@ -335,7 +335,7 @@ def loglikelihood_rolling(
             request.tokenized_context = [self.tokenizer.eos_token_id]
             request.tokenized_continuation = self.tok_encode(request.context)
 
-        dataset = LoglikelihoodDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
+        dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
         batch_size = override_bs if override_bs is not None else BATCH_SIZE
 
         results: List[str] = []
diff --git a/src/lighteval/models/nanotron_model.py b/src/lighteval/models/nanotron_model.py
index b75bc2b27..560b29b35 100644
--- a/src/lighteval/models/nanotron_model.py
+++ b/src/lighteval/models/nanotron_model.py
@@ -643,9 +643,9 @@ def pad_and_gather(self, output_tensor: torch.Tensor) -> Tuple[torch.Tensor, tor
 
         return gathered_outputs, gathered_length
 
-    def _get_subsets(self, dataset, dataset_splits):
+    def _get_subsets(self, dataset, num_dataset_splits):
         total_length = len(dataset)
-        subset_length = int(float(total_length) / float(dataset_splits)) + 1
+        subset_length = int(float(total_length) / float(num_dataset_splits)) + 1
         if subset_length < self.parallel_context.dp_pg.size():
             # We need at least one subset sample per DP process
             subset_length = self.parallel_context.dp_pg.size()
@@ -653,7 +653,7 @@ def _get_subsets(self, dataset, dataset_splits):
 
     @torch.inference_mode()
     def _loglikelihood_single_token(
-        self, requests, disable_tqdm: bool = False, override_bs: int = -1, dataset_splits: int = 1
+        self, requests, disable_tqdm: bool = False, override_bs: int = -1, num_dataset_splits: int = 1
     ) -> List[LoglikelihoodSingleTokenReturn]:
         dataset = LoglikelihoodSingleTokenDataset(requests=requests)
         res = []
@@ -663,7 +663,7 @@ def _loglikelihood_single_token(
         printed_error = False
         starting_batch_size = 512
 
-        total_length, subset_length = self._get_subsets(dataset, dataset_splits)
+        total_length, subset_length = self._get_subsets(dataset, num_dataset_splits)
 
         for s, subset_start in enumerate(
             tqdm(
@@ -883,17 +883,17 @@ def _loglikelihood_tokens(
         requests,
         disable_tqdm: bool = False,
         override_bs: int = -1,
-        dataset_splits: int = 1,
+        num_dataset_splits: int = 1,
         return_bool_score: bool = True,
     ) -> List[LoglikelihoodReturn]:
-        dataset = LoglikelihoodDataset(requests=requests, dataset_splits=dataset_splits)
+        dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=num_dataset_splits)
         res = []
 
         # Dataset is sorted in descending size.
# every 20-25% of the dataset we try to double the batch size for speed up starting_batch_size = 512 - total_length, subset_length = self._get_subsets(dataset, dataset_splits) + total_length, subset_length = self._get_subsets(dataset, num_dataset_splits) for s, subset_start in enumerate( tqdm( @@ -1117,7 +1117,7 @@ def greedy_until( requests: List[GreedyUntilRequest], disable_tqdm: bool = False, override_bs=None, - dataset_splits: int = 1, + num_dataset_splits: int = 1, ) -> List[GenerateReturn]: """Greedy generation until a stop token is generated.""" # automatic (variable) batch size detection for vectorization @@ -1126,14 +1126,14 @@ def greedy_until( request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token] request.tokenized_context = self.tok_encode(request.context) - dataset = GenerativeTaskDatasetNanotron(requests=requests, dataset_splits=dataset_splits) + dataset = GenerativeTaskDatasetNanotron(requests=requests, num_dataset_splits=num_dataset_splits) res = [] # Dataset is sorted in descending size. # every 20-25% of the dataset we try to double the batch size for speed up starting_batch_size = 512 - total_length, subset_length = self._get_subsets(dataset, dataset_splits) + total_length, subset_length = self._get_subsets(dataset, num_dataset_splits) for s, subset_start in enumerate( tqdm( diff --git a/tests/test_unit_reorder.py b/tests/test_unit_reorder.py index 1936bcc59..6487cd93f 100644 --- a/tests/test_unit_reorder.py +++ b/tests/test_unit_reorder.py @@ -77,7 +77,7 @@ class TestReorderGenerativeTaskDataset: def test_dataset_needs_tokenization(self): with pytest.raises(ValueError): - GenerativeTaskDataset(requests=TEST_DATA, dataset_splits=DATASET_SPLITS) + GenerativeTaskDataset(requests=TEST_DATA, num_dataset_splits=DATASET_SPLITS) def test_reorder_dataset(self): tokenizer = AutoTokenizer.from_pretrained("gpt2") @@ -85,7 +85,7 @@ def test_reorder_dataset(self): for request in data: request.tokenized_context = tokenizer.encode(request.context) - dataset = GenerativeTaskDataset(requests=data, dataset_splits=DATASET_SPLITS) + dataset = GenerativeTaskDataset(requests=data, num_dataset_splits=DATASET_SPLITS) sorted_data = dataset.sorted_data original_data = dataset.get_original_order(sorted_data)
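Whatever number of groups init_split_limits infers, the resulting limits should partition the sorted data contiguously. Below is a hedged sketch of such a check, written in the style of the test class above; the test name, the deepcopy, and its placement in TestReorderGenerativeTaskDataset are suggestions, not part of this PR.

    def test_splits_partition_dataset(self):
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        data = copy.deepcopy(TEST_DATA)  # assumes `import copy` at module level
        for request in data:
            request.tokenized_context = tokenizer.encode(request.context)

        dataset = GenerativeTaskDataset(requests=data, num_dataset_splits=DATASET_SPLITS)

        limits = list(dataset.splits_start_end_iterator())
        # Splits must start at 0, end at total_size, and be contiguous.
        assert limits[0][0] == 0
        assert limits[-1][1] == dataset.total_size
        for (_, prev_end), (next_start, _) in zip(limits, limits[1:]):
            assert prev_end == next_start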