Data split depending on eval params #169

Merged · 21 commits · Jul 11, 2024
Changes from 18 commits
75 changes: 46 additions & 29 deletions src/lighteval/data.py
@@ -40,12 +40,12 @@ class DynamicBatchDataset(Dataset):
def __init__(
self,
requests: list,
dataset_splits: int,
num_dataset_splits: int,
):
"""
This dataset class uses dynamic batching to speed up the generation.
Each request is sorted by the length of the prompt + the length of the
continuation. Then, the dataset is split into dataset_splits splits.
continuation. Then, the dataset is split into num_dataset_splits splits.
The first split will contain the longest requests, the second split will
contain the second longest requests, etc. This allows us to use dynamic
batching by starting with a small batch size and doubling it for each
@@ -54,7 +54,7 @@ def __init__(

Args:
requests (List): A list of requests.
dataset_splits (int): The number of dataset splits.
num_dataset_splits (int): The number of dataset splits.
"""
# We make sure the requests contain the tokenized versions of their values
if any(r.tokenized_context is None for r in requests):
@@ -69,16 +69,24 @@ def __init__(

self.total_size = len(self.sorted_data)

if dataset_splits >= self.total_size:
self.num_dataset_splits, self.splits = self.init_split_limits(num_dataset_splits)

self.split_start, self.split_end = self.splits[0]

def init_split_limits(self, num_dataset_splits):
if num_dataset_splits >= self.total_size:
hlog_warn(
f"dataset_splits ({dataset_splits}) >= total_size ({self.total_size}), setting dataset_splits to 1"
f"num_dataset_splits ({num_dataset_splits}) >= total_size ({self.total_size}), setting num_dataset_splits to 1"
)
dataset_splits = 1
num_dataset_splits = 1

split_size = self.total_size // num_dataset_splits + 1
splits_indices = [
(ix * split_size, min((ix + 1) * split_size, self.total_size)) for ix in range(num_dataset_splits)
]

self.dataset_splits = dataset_splits
self.split_size = self.total_size // self.dataset_splits + 1
self.split_start = 0
self.split_end = min(self.split_start + self.split_size, self.total_size)
return num_dataset_splits, splits_indices

def get_original_order(self, new_arr: list) -> list:
"""
@@ -113,8 +121,8 @@ def get_split_start_end(self, split_id: int) -> tuple[int, int]:
Returns:
tuple: A tuple containing the start and end indices of the split.
"""
self.split_start = split_id * self.split_size
self.split_end = min(self.split_start + self.split_size, self.total_size)
self.split_start, self.split_end = self.splits[split_id]
return self.split_start, self.split_end

def splits_start_end_iterator(self) -> tuple[int, int]:
@@ -126,7 +134,7 @@ def splits_start_end_iterator(self) -> tuple[int, int]:
Yields:
tuple: A tuple containing the start and end indices of a split.
"""
for split_id in range(self.dataset_splits):
for split_id in range(self.num_dataset_splits):
yield self.get_split_start_end(split_id)

def __getitem__(self, index) -> Request:
@@ -204,6 +212,29 @@ def _sorting_criteria(self, request: LoglikelihoodSingleTokenRequest) -> int:


class GenerativeTaskDataset(DynamicBatchDataset):
def init_split_limits(self, num_dataset_splits):
if num_dataset_splits is not None:
hlog_warn(
"You cannot select the number of dataset splits for a generative evaluation at the moment. Automatically inferring."
)

all_sorting_criterion = [self._sorting_criteria(self.sorted_data[0])[:2]]
splits_indices = [[0, None]]
for ix, req in enumerate(self.sorted_data):
current_sorting_criteria = self._sorting_criteria(req)
current_key = current_sorting_criteria[:2]
if current_key not in all_sorting_criterion:
all_sorting_criterion.append(current_key)
splits_indices[-1][1] = ix
splits_indices.append([ix, None])

# We add the last split
splits_indices[-1][1] = self.total_size

num_dataset_splits = len(splits_indices)
splits_indices = [tuple(e) for e in splits_indices]
return num_dataset_splits, splits_indices

def _sorting_criteria(self, request: GreedyUntilRequest) -> int:
"""
Collate function for generating batches.
@@ -219,10 +250,10 @@ def _sorting_criteria(self, request: GreedyUntilRequest) -> int:
# The generative task has no limit except the model context
if gen_length is None:
gen_length = 0
return -(len(toks) + gen_length)
return request.use_logits, request.stop_sequence, -(len(toks) + gen_length)


class GenerativeTaskDatasetNanotron(DynamicBatchDataset):
class GenerativeTaskDatasetNanotron(GenerativeTaskDataset):
def __getitem__(self, index) -> Request:
"""
Get an item from the dataset depending on the split we are currently in.
@@ -238,20 +269,6 @@ def __getitem__(self, index) -> Request:
"""
return index, self.sorted_data[index + self.split_start]

def _sorting_criteria(self, request) -> int:
"""
Collate function for generating batches.

Args:
x (Any): The input data.

Returns:
Any: The collated data.
"""
toks = request.tokenized_context
gen_length = request.generation_size
return -(len(toks) + gen_length)


class GenDistributedSampler(DistributedSampler):
"""A distributed sampler that copy the last element only when drop_last is False so we keep a small padding in the batches
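The splitting logic above is easier to read outside diff form. Below is a minimal, standalone sketch of the two strategies this PR introduces: equal-size splits in `DynamicBatchDataset.init_split_limits`, and splits that start whenever the `(use_logits, stop_sequence)` pair of the (already sorted) requests changes in `GenerativeTaskDataset.init_split_limits`. `SimpleRequest` and the example values are illustrative only and are not part of lighteval.

```python
from dataclasses import dataclass


@dataclass
class SimpleRequest:
    # Illustrative stand-in for a lighteval request; only the two fields
    # used to group generative splits are modeled here.
    use_logits: bool
    stop_sequence: tuple


def equal_split_limits(total_size: int, num_dataset_splits: int) -> tuple[int, list[tuple[int, int]]]:
    """Equal-size splits, mirroring DynamicBatchDataset.init_split_limits."""
    if num_dataset_splits >= total_size:
        num_dataset_splits = 1  # same clamping as the hlog_warn branch above
    split_size = total_size // num_dataset_splits + 1
    splits_indices = [
        (ix * split_size, min((ix + 1) * split_size, total_size)) for ix in range(num_dataset_splits)
    ]
    return num_dataset_splits, splits_indices


def grouped_split_limits(sorted_requests: list[SimpleRequest]) -> tuple[int, list[tuple[int, int]]]:
    """Splits that start whenever (use_logits, stop_sequence) changes,
    mirroring GenerativeTaskDataset.init_split_limits; requests are assumed
    to be pre-sorted by _sorting_criteria."""
    keys_seen = [(sorted_requests[0].use_logits, sorted_requests[0].stop_sequence)]
    splits_indices = [[0, None]]
    for ix, req in enumerate(sorted_requests):
        key = (req.use_logits, req.stop_sequence)
        if key not in keys_seen:
            keys_seen.append(key)
            splits_indices[-1][1] = ix
            splits_indices.append([ix, None])
    splits_indices[-1][1] = len(sorted_requests)
    return len(splits_indices), [tuple(e) for e in splits_indices]


# Three requests with two distinct stop sequences yield two splits;
# asking for more splits than there are requests falls back to one split.
reqs = [
    SimpleRequest(False, ("\n",)),
    SimpleRequest(False, ("\n",)),
    SimpleRequest(False, ("###",)),
]
print(grouped_split_limits(reqs))  # (2, [(0, 2), (2, 3)])
print(equal_split_limits(3, 4))    # (1, [(0, 3)])
```

Because the new `_sorting_criteria` sorts by `(use_logits, stop_sequence, -(prompt + generation length))`, every request inside a generative split shares the same generation parameters, which is what lets the caller keep doubling the batch size per split.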
16 changes: 6 additions & 10 deletions src/lighteval/metrics/__init__.py
@@ -95,22 +95,18 @@ def apply_generative_metric(
preds = [formatted_doc.specific["label_to_choices"].get(p) for p in preds]
golds = [formatted_doc.specific["label_to_choices"][g] for g in golds]

preds_no_sampling = preds
if max_num_samples > 1: # We want to run our evaluation on only one sample for base generative evals
preds_no_sampling = as_list(preds[0])

for metric in metrics:
if Metrics[metric].value.category == MetricCategory.GENERATIVE:
outputs.update(
Metrics[metric].value.compute(
golds=golds,
predictions=as_list(preds[0]) if max_num_samples > 0 else preds,
formatted_doc=formatted_doc,
)
Metrics[metric].value.compute(golds=golds, predictions=preds_no_sampling, formatted_doc=formatted_doc)
)
if Metrics[metric].value.category == MetricCategory.GENERATIVE_LOGPROB:
outputs.update(
Metrics[metric].value.compute(
golds=golds,
predictions=as_list(preds[0]) if max_num_samples > 0 else preds,
formatted_doc=formatted_doc,
)
Metrics[metric].value.compute(golds=golds, predictions=preds_no_sampling, formatted_doc=formatted_doc)
)
if Metrics[metric].value.category == MetricCategory.GENERATIVE_SAMPLING:
outputs.update(Metrics[metric].value.compute(golds=golds, predictions=preds, formatted_doc=formatted_doc))
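The metrics change hoists the "first sample only" selection out of the per-metric calls. A hedged sketch of the resulting behaviour; the helper name `select_preds_for_metric` is hypothetical (in the PR this is inlined in `apply_generative_metric`), and the values are illustrative:

```python
from typing import Any


def select_preds_for_metric(preds: list[Any], max_num_samples: int) -> list[Any]:
    # GENERATIVE and GENERATIVE_LOGPROB metrics are scored on a single sample
    # for base generative evals; GENERATIVE_SAMPLING metrics keep all samples.
    preds_no_sampling = preds
    if max_num_samples > 1:
        preds_no_sampling = [preds[0]]
    return preds_no_sampling


print(select_preds_for_metric(["answer A", "answer B", "answer C"], max_num_samples=3))  # ['answer A']
print(select_preds_for_metric(["answer A"], max_num_samples=1))                           # ['answer A']
```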
10 changes: 5 additions & 5 deletions src/lighteval/models/base_model.py
@@ -335,7 +335,7 @@ def greedy_until_multi_turn( # noqa: C901

results = []

dataset = GenerativeTaskDataset(requests=requests, dataset_splits=1)
dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=1)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch)

if self.accelerator:
@@ -480,13 +480,13 @@ def greedy_until(
request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]
request.tokenized_context = self.tok_encode(request.context)

dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
starting_batch_size = STARTING_BATCH_SIZE
results = []

for split_start, split_end in tqdm(
dataset.splits_start_end_iterator(),
total=self.DATASET_SPLITS,
total=dataset.num_dataset_splits,
desc="Splits",
position=0,
disable=self.disable_tqdm,
@@ -715,7 +715,7 @@ def _loglikelihood_tokens(
return_bool_score: bool = True,
rolling: bool = False,
) -> list[LoglikelihoodReturn]:
dataset = LoglikelihoodDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
starting_batch_size = STARTING_BATCH_SIZE
res = []

@@ -957,7 +957,7 @@ def loglikelihood_single_token(
def _loglikelihood_single_token(
self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: int = -1
) -> list[LoglikelihoodSingleTokenReturn]:
dataset = LoglikelihoodSingleTokenDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
dataset = LoglikelihoodSingleTokenDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
starting_batch_size = STARTING_BATCH_SIZE
res = []

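Callers now read the number of splits from the dataset itself (`total=dataset.num_dataset_splits`) rather than from the model constant, since `GenerativeTaskDataset` infers its own split count. A simplified, hedged sketch of that consumption pattern; the real `greedy_until` also probes the maximum batch size and wraps each split in a `DataLoader`, and `process_batch` here is only a caller-supplied placeholder:

```python
from tqdm import tqdm


def run_splits(dataset, process_batch, disable_tqdm: bool = False) -> list:
    """Iterate the dataset split by split; `process_batch` stands in for the
    model-specific generation or loglikelihood code."""
    results = []
    for split_start, split_end in tqdm(
        dataset.splits_start_end_iterator(),
        total=dataset.num_dataset_splits,  # inferred by the dataset, not a model constant
        desc="Splits",
        disable=disable_tqdm,
    ):
        # __getitem__ / __len__ are relative to the current split, so indexing
        # the dataset here only touches this split's requests.
        split_requests = [dataset[i] for i in range(split_end - split_start)]
        results.extend(process_batch(split_requests))
    return results
```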
6 changes: 3 additions & 3 deletions src/lighteval/models/endpoint_model.py
@@ -243,7 +243,7 @@ def greedy_until(
request.tokenized_context = self.tok_encode(request.context)
request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]

dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
batch_size = override_bs if override_bs is not None else BATCH_SIZE
results: List[str] = []

@@ -289,7 +289,7 @@ def loglikelihood(
for request in requests:
request.tokenized_context = self.tok_encode(request.context)
request.tokenized_continuation = self.tok_encode(request.choice)
dataset = LoglikelihoodDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
batch_size = override_bs if override_bs is not None else BATCH_SIZE
results: List[str] = []

@@ -335,7 +335,7 @@ def loglikelihood_rolling(
request.tokenized_context = [self.tokenizer.eos_token_id]
request.tokenized_continuation = self.tok_encode(request.context)

dataset = LoglikelihoodDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
batch_size = override_bs if override_bs is not None else BATCH_SIZE
results: List[str] = []

20 changes: 10 additions & 10 deletions src/lighteval/models/nanotron_model.py
@@ -643,17 +643,17 @@ def pad_and_gather(self, output_tensor: torch.Tensor) -> Tuple[torch.Tensor, tor

return gathered_outputs, gathered_length

def _get_subsets(self, dataset, dataset_splits):
def _get_subsets(self, dataset, num_dataset_splits):
total_length = len(dataset)
subset_length = int(float(total_length) / float(dataset_splits)) + 1
subset_length = int(float(total_length) / float(num_dataset_splits)) + 1
if subset_length < self.parallel_context.dp_pg.size():
# We need at least one subset sample per DP process
subset_length = self.parallel_context.dp_pg.size()
return total_length, subset_length

@torch.inference_mode()
def _loglikelihood_single_token(
self, requests, disable_tqdm: bool = False, override_bs: int = -1, dataset_splits: int = 1
self, requests, disable_tqdm: bool = False, override_bs: int = -1, num_dataset_splits: int = 1
) -> List[LoglikelihoodSingleTokenReturn]:
dataset = LoglikelihoodSingleTokenDataset(requests=requests)
res = []
@@ -663,7 +663,7 @@ def _loglikelihood_single_token(
printed_error = False
starting_batch_size = 512

total_length, subset_length = self._get_subsets(dataset, dataset_splits)
total_length, subset_length = self._get_subsets(dataset, num_dataset_splits)

for s, subset_start in enumerate(
tqdm(
@@ -883,17 +883,17 @@ def _loglikelihood_tokens(
requests,
disable_tqdm: bool = False,
override_bs: int = -1,
dataset_splits: int = 1,
num_dataset_splits: int = 1,
return_bool_score: bool = True,
) -> List[LoglikelihoodReturn]:
dataset = LoglikelihoodDataset(requests=requests, dataset_splits=dataset_splits)
dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=num_dataset_splits)
res = []

# Dataset is sorted in descending size.
# every 20-25% of the dataset we try to double the batch size for speed up
starting_batch_size = 512

total_length, subset_length = self._get_subsets(dataset, dataset_splits)
total_length, subset_length = self._get_subsets(dataset, num_dataset_splits)

for s, subset_start in enumerate(
tqdm(
@@ -1117,7 +1117,7 @@ def greedy_until(
requests: List[GreedyUntilRequest],
disable_tqdm: bool = False,
override_bs=None,
dataset_splits: int = 1,
num_dataset_splits: int = 1,
) -> List[GenerateReturn]:
"""Greedy generation until a stop token is generated."""
# automatic (variable) batch size detection for vectorization
@@ -1126,14 +1126,14 @@
request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]
request.tokenized_context = self.tok_encode(request.context)

dataset = GenerativeTaskDatasetNanotron(requests=requests, dataset_splits=dataset_splits)
dataset = GenerativeTaskDatasetNanotron(requests=requests, num_dataset_splits=num_dataset_splits)
res = []

# Dataset is sorted in descending size.
# every 20-25% of the dataset we try to double the batch size for speed up
starting_batch_size = 512

total_length, subset_length = self._get_subsets(dataset, dataset_splits)
total_length, subset_length = self._get_subsets(dataset, num_dataset_splits)

for s, subset_start in enumerate(
tqdm(
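The `_get_subsets` rename keeps the same arithmetic: each subset must hold at least one sample per data-parallel rank. A small sketch, under the assumption that `dp_world_size` stands in for `parallel_context.dp_pg.size()`:

```python
def get_subsets(total_length: int, num_dataset_splits: int, dp_world_size: int) -> tuple[int, int]:
    subset_length = int(float(total_length) / float(num_dataset_splits)) + 1
    if subset_length < dp_world_size:
        # We need at least one subset sample per DP process
        subset_length = dp_world_size
    return total_length, subset_length


print(get_subsets(100, 3, 8))  # (100, 34): roughly total/splits, well above the DP floor
print(get_subsets(10, 8, 4))   # (10, 4): clamped up to one sample per DP rank
```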
4 changes: 2 additions & 2 deletions tests/test_unit_reorder.py
@@ -77,15 +77,15 @@
class TestReorderGenerativeTaskDataset:
def test_dataset_needs_tokenization(self):
with pytest.raises(ValueError):
GenerativeTaskDataset(requests=TEST_DATA, dataset_splits=DATASET_SPLITS)
GenerativeTaskDataset(requests=TEST_DATA, num_dataset_splits=DATASET_SPLITS)

def test_reorder_dataset(self):
tokenizer = AutoTokenizer.from_pretrained("gpt2")
data = TEST_DATA.copy()
for request in data:
request.tokenized_context = tokenizer.encode(request.context)

dataset = GenerativeTaskDataset(requests=data, dataset_splits=DATASET_SPLITS)
dataset = GenerativeTaskDataset(requests=data, num_dataset_splits=DATASET_SPLITS)

sorted_data = dataset.sorted_data
original_data = dataset.get_original_order(sorted_data)