Merge branch 'main' into fix-brrr
clefourrier committed Feb 7, 2024
2 parents e47bad9 + 1e837a9 commit e93fb58
Showing 21 changed files with 967 additions and 435 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -50,7 +50,7 @@ keywords = ["evaluation", "nlp", "llm"]
dependencies = [
# Base dependencies
"transformers>=4.36.0",
"huggingface_hub==0.19.4",
"huggingface_hub==0.20.3",
"torch>=2.0",
"GitPython==3.1.31", # for logging
"datasets>=2.14.0",
40 changes: 27 additions & 13 deletions run_evals_accelerate.py
@@ -6,8 +6,10 @@
def get_parser():
parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group(required=True)
weight_type_group = parser.add_mutually_exclusive_group()
task_type_group = parser.add_mutually_exclusive_group(required=True)

# Model type 1) Base model
weight_type_group = parser.add_mutually_exclusive_group()
weight_type_group.add_argument(
"--delta_weights",
action="store_true",
@@ -24,39 +26,50 @@ def get_parser():
"--base_model", type=str, default=None, help="name of the base model to be used for delta or adapter weights"
)

parser.add_argument("--model_args", required=True)
parser.add_argument("--output_dir", required=True)
task_type_group.add_argument("--model_args")
parser.add_argument("--model_dtype", type=str, default=None)
parser.add_argument(
"--multichoice_continuations_start_space",
action="store_true",
help="Whether to force multiple choice continuations starts with a space",
help="Whether to force multiple choice continuations to start with a space",
)
parser.add_argument(
"--no_multichoice_continuations_start_space",
action="store_true",
help="Whether to force multiple choice continuations do not starts with a space",
help="Whether to force multiple choice continuations to not start with a space",
)
parser.add_argument("--use_chat_template", default=False, action="store_true")
# Model type 2) TGI
task_type_group.add_argument("--inference_server_address", type=str)
parser.add_argument("--inference_server_auth", type=str, default=None)
# Model type 3) Inference endpoints
task_type_group.add_argument("--endpoint_model_name", type=str)
parser.add_argument("--accelerator", type=str, default=None)
parser.add_argument("--vendor", type=str, default=None)
parser.add_argument("--region", type=str, default=None)
parser.add_argument("--instance_size", type=str, default=None)
parser.add_argument("--instance_type", type=str, default=None)
parser.add_argument("--reuse_existing", default=False, action="store_true")
# Debug
parser.add_argument("--max_samples", type=int, default=None)
parser.add_argument("--job_id", type=str, help="Optional Job ID for future reference", default="")
# Saving
parser.add_argument("--push_results_to_hub", default=False, action="store_true")
parser.add_argument("--save_details", action="store_true")
parser.add_argument("--push_details_to_hub", default=False, action="store_true")
parser.add_argument(
"--public_run", default=False, action="store_true", help="Push results and details to a public repo"
)
parser.add_argument("--max_samples", type=int, default=None)
parser.add_argument("--override_batch_size", type=int, default=-1)
parser.add_argument("--dataset_loading_processes", type=int, default=1)
parser.add_argument("--inference_server_address", type=str, default=None)
parser.add_argument("--inference_server_auth", type=str, default=None)
parser.add_argument("--num_fewshot_seeds", type=int, default=1, help="Number of trials the few shots")
parser.add_argument("--cache_dir", type=str, default=CACHE_DIR)
parser.add_argument(
"--results_org",
type=str,
help="Hub organisation where you want to store the results. Your current token must have write access to it",
)
parser.add_argument("--job_id", type=str, help="Optional Job ID for future reference", default="")
parser.add_argument("--use_chat_template", default=False, action="store_true")
# Common parameters
parser.add_argument("--output_dir", required=True)
parser.add_argument("--override_batch_size", type=int, default=-1)
parser.add_argument("--dataset_loading_processes", type=int, default=1)
parser.add_argument(
"--custom_tasks_file",
type=str,
@@ -69,6 +82,7 @@ def get_parser():
default=None,
help="Id of a task, e.g. 'original|mmlu:abstract_algebra|5' or path to a texte file with a list of tasks",
)
parser.add_argument("--num_fewshot_seeds", type=int, default=1, help="Number of trials the few shots")
return parser


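The reorganised parser groups the backend flags into a required mutually exclusive group, so each run selects exactly one of base model, TGI server, or inference endpoint. Below is a minimal standalone sketch of that argparse pattern; the example values are hypothetical, not lighteval defaults.

import argparse

# Sketch of the backend-selection pattern used above: exactly one of the three
# "task type" flags must be supplied; argparse enforces the exclusivity.
parser = argparse.ArgumentParser()
task_type_group = parser.add_mutually_exclusive_group(required=True)
task_type_group.add_argument("--model_args")                # 1) local base model
task_type_group.add_argument("--inference_server_address")  # 2) TGI server
task_type_group.add_argument("--endpoint_model_name")       # 3) inference endpoint
parser.add_argument("--output_dir", required=True)

# Passing two of the exclusive flags together raises a parser error, as does
# passing none of them.
args = parser.parse_args(["--model_args", "pretrained=gpt2", "--output_dir", "out"])
print(args.model_args, args.output_dir)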
44 changes: 23 additions & 21 deletions src/lighteval/data.py
@@ -5,7 +5,14 @@
from torch.utils.data.distributed import DistributedSampler, T_co

from lighteval.logging.hierarchical_logger import hlog_warn
from lighteval.tasks.requests import Request
from lighteval.tasks.requests import (
GreedyUntilRequest,
GreedyUntilWithLogitsRequest,
LoglikelihoodRequest,
LoglikelihoodRollingRequest,
LoglikelihoodSingleTokenRequest,
Request,
)


class DynamicBatchDataset(Dataset):
@@ -28,6 +35,9 @@ def __init__(
requests (List): A list of requests.
dataset_splits (int): The number of dataset splits.
"""
# We make sure the requests contain the tokenized versions of their values
if any(r.tokenized_context is None for r in requests):
raise ValueError("You passed a request for which tokenization had not happened yet.")

# sort the requests using the collate function and save the original order
enumerated_requests = list(enumerate(requests))
@@ -124,12 +134,12 @@ def __len__(self) -> int:
"""
return self.split_end - self.split_start

def _sorting_criteria(self, x) -> int:
def _sorting_criteria(self, request) -> int:
raise NotImplementedError()


class LoglikelihoodDataset(DynamicBatchDataset):
def _sorting_criteria(self, x) -> int:
def _sorting_criteria(self, request: LoglikelihoodRequest | LoglikelihoodRollingRequest) -> int:
"""
Collates the input data for batching.
@@ -149,13 +159,12 @@ def _sorting_criteria(self, x) -> int:
Returns:
tuple: A tuple containing the sorted input data.
"""

toks = x[1] + x[2]
toks = request.tokenized_context + request.tokenized_continuation
return -len(toks)


class LoglikelihoodSingleTokenDataset(DynamicBatchDataset):
def _sorting_criteria(self, x) -> int:
def _sorting_criteria(self, request: LoglikelihoodSingleTokenRequest) -> int:
"""
Collates the input data for batching.
@@ -167,19 +176,14 @@ def _sorting_criteria(self, x) -> int:
is useful to simplify the batching logic and more importantly to make
automatic adaptive batches much much easier to implement
- any OOMs will happen right away rather than near the end
Args:
x (tuple): A tuple containing the input data.
Returns:
tuple: A tuple containing the collated data.
"""
toks = x[1] # We take only the prompt, no need for the continuation (since it's a list of single tokens)
# We take only the prompt, no need for the continuation (since it's a list of single tokens)
toks = request.tokenized_context
return -len(toks)


class GenerativeTaskDataset(DynamicBatchDataset):
def _sorting_criteria(self, x) -> int:
def _sorting_criteria(self, request: GreedyUntilRequest | GreedyUntilWithLogitsRequest) -> int:
"""
Collate function for generating batches.
@@ -189,9 +193,8 @@ def _sorting_criteria(self, x) -> int:
Returns:
Any: The collated data.
"""
toks = x[0]
meta_data = x[1]
_, gen_length = meta_data[0], meta_data[1]
toks = request.tokenized_context
gen_length = request.generation_size
return -(len(toks) + gen_length)


@@ -211,7 +214,7 @@ def __getitem__(self, index) -> Request:
"""
return index, self.sorted_data[index + self.split_start]

def _sorting_criteria(self, x) -> int:
def _sorting_criteria(self, request) -> int:
"""
Collate function for generating batches.
@@ -221,9 +224,8 @@ def _sorting_criteria(self, x) -> int:
Returns:
Any: The collated data.
"""
toks = x[0]
meta_data = x[1]
_, gen_length = meta_data[0], meta_data[1]
toks = request.tokenized_context
gen_length = request.generation_size
return -(len(toks) + gen_length)


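The rewritten sorting criteria read token lengths from typed request objects instead of positional tuples. Below is a minimal sketch of the underlying longest-first ordering, using a simplified stand-in for GreedyUntilRequest rather than the real class.

from dataclasses import dataclass

@dataclass
class FakeGreedyUntilRequest:  # simplified stand-in for lighteval's GreedyUntilRequest
    tokenized_context: list
    generation_size: int

def sorting_criteria(request: FakeGreedyUntilRequest) -> int:
    # Negative length: an ascending sort then puts the longest items first, so
    # any out-of-memory error surfaces in the very first batches.
    return -(len(request.tokenized_context) + request.generation_size)

requests = [
    FakeGreedyUntilRequest(tokenized_context=[1] * 10, generation_size=256),
    FakeGreedyUntilRequest(tokenized_context=[1] * 500, generation_size=32),
    FakeGreedyUntilRequest(tokenized_context=[1] * 50, generation_size=8),
]
ordered = sorted(requests, key=sorting_criteria)
print([len(r.tokenized_context) + r.generation_size for r in ordered])  # [532, 266, 58]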
2 changes: 1 addition & 1 deletion src/lighteval/evaluator.py
@@ -8,7 +8,7 @@
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.logging.hierarchical_logger import hlog
from lighteval.models.base_model import BaseModel
from lighteval.models.inference_client import ModelClient
from lighteval.models.tgi_model import ModelClient
from lighteval.tasks.lighteval_task import LightevalTask
from lighteval.tasks.requests import Doc, Request, RequestType, TaskExampleId

2 changes: 2 additions & 0 deletions src/lighteval/main_accelerate.py
@@ -121,4 +121,6 @@ def main(args):

print(make_results_table(final_dict))

model.cleanup()

return final_dict
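main() now releases the model before returning its results. Here is a hedged sketch of the calling pattern, assuming cleanup() frees accelerator state and GPU memory; the try/finally is a defensive variant for illustration, not what the diff itself does, and all helper names are hypothetical.

def run_evaluation(args, load_model, evaluate, make_results_table):
    # load_model, evaluate and make_results_table are hypothetical stand-ins
    # for lighteval internals.
    model = load_model(args)
    try:
        final_dict = evaluate(model, args)
        print(make_results_table(final_dict))
    finally:
        model.cleanup()  # assumed to release GPU memory / temporary resources
    return final_dict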
21 changes: 16 additions & 5 deletions src/lighteval/metrics/__init__.py
@@ -8,11 +8,18 @@

def apply_target_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metrics: list[str]):
outputs = {}
current_results = [results.pop(0) for _ in range(len(formatted_doc.get_golds()))]
reference_text = formatted_doc.get_golds()[0]
current_result = results.pop(0)
target_logprob = current_result.result[0]
target_acc = current_result.result[1]

for metric in metrics:
if Metrics[metric].value.category == MetricCategory.PERPLEXITY:
outputs.update(Metrics[metric].value.compute(results=current_results))
if Metrics[metric].value.category == MetricCategory.TARGET_PERPLEXITY:
outputs.update(
Metrics[metric].value.compute(
logprobs=target_logprob, target_acc=target_acc, reference_text=reference_text
)
)

return results, outputs

@@ -30,7 +37,9 @@ def apply_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metr

for metric in metrics:
if Metrics[metric].value.category == MetricCategory.PERPLEXITY:
outputs.update(Metrics[metric].value.compute(results=current_result, reference_text=reference_text))
outputs.update(
Metrics[metric].value.compute(logprobs=current_result.result, reference_text=reference_text)
)

return results, outputs

@@ -85,7 +94,9 @@ def apply_multichoice_metric(results: list[ModelReturn], formatted_doc: Doc, met
raise ValueError(
"You can't use a multi choice metric with only one choice. Use `acc_golds_likelihood` instead."
)
choices_logprob = [results[i].result[0] for i in range(len(formatted_doc.choices))]

# Todo: make better system with return_bool_score instead of taking first element
choices_logprob = [results[i].result[0] for i in range(len(formatted_doc.choices))] # sum(
gold_ixs = as_list(formatted_doc.gold_index)

for metric in metrics:
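apply_target_perplexity_metric now consumes a single ModelReturn per document and unpacks its result into a log-probability and an accuracy flag before dispatching to the metric. Below is a standalone sketch of that unpacking, with a simplified stand-in for ModelReturn and a plain dict in place of the metric call.

from dataclasses import dataclass

@dataclass
class FakeModelReturn:  # simplified stand-in for lighteval's ModelReturn
    result: tuple       # (summed logprob of the target, accuracy flag)

def apply_target_perplexity(results: list, reference_text: str) -> dict:
    current_result = results.pop(0)  # one return per document
    target_logprob, target_acc = current_result.result
    # The real code forwards these as keyword arguments to
    # Metrics[metric].value.compute(logprobs=..., target_acc=..., reference_text=...).
    return {"target_logprob": target_logprob, "target_acc": target_acc}

print(apply_target_perplexity([FakeModelReturn(result=(-12.3, 1))], "the gold answer"))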
9 changes: 5 additions & 4 deletions src/lighteval/metrics/metrics_sample.py
@@ -1,6 +1,8 @@
"""This module manages all the metrics occurring at the sample level. The results of said metrics are then aggregated
using simple functions (min, mean, max, ...) at the corpus level. Most metrics fall under this category.
"""
from typing import Union

import nltk
import numpy as np
from nltk.metrics.distance import edit_distance
@@ -275,17 +277,16 @@ def compute(self, choices_logprob: list[float], gold_ixs: list[float], formatted
return 1.0 / (min(ranked_choices) + 1)


def acc_golds_likelihood(results: list[tuple[float, int]], **kwargs) -> int:
def acc_golds_likelihood(target_acc: Union[list[int], int], **kwargs) -> int:
"""Tests if at least one of predicted gold targets' log-likelihood is above 0.5.
Args:
results (list[int]): List of tuples containing, for each gold, the predictions log-probabilities associated with whether they are above 0.5 aggregated.
formatted_doc (Doc): _description_
target_acc (list[int]): List of scores indicating whether the predictions log-probabilities are above 0.5 aggregated.
Returns:
int: 1 if at least one of the possible golds had a log-likelihood above 0.5.
"""
return max([int(acc_ppl) for _, acc_ppl in results])
return max([int(acc_ppl) for acc_ppl in as_list(target_acc)])


class ROUGE:
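acc_golds_likelihood now receives the pre-computed accuracy flags directly (one per gold) instead of (logprob, flag) tuples. A minimal sketch of the behaviour follows, with a tiny as_list helper standing in for lighteval's own utility.

from typing import Union

def as_list(item):  # stand-in for lighteval's as_list utility
    return item if isinstance(item, (list, tuple)) else [item]

def acc_golds_likelihood(target_acc: Union[list, int], **kwargs) -> int:
    # 1 if at least one gold target was judged likely (its flag is set), else 0.
    return max(int(acc) for acc in as_list(target_acc))

print(acc_golds_likelihood([0, 1, 0]))  # 1
print(acc_golds_likelihood(0))          # 0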
6 changes: 3 additions & 3 deletions src/lighteval/metrics/sample_preparator.py
@@ -106,14 +106,14 @@ def count_units(self, text: str) -> int:
if self.units_type == "bytes":
return len(text.encode("utf-8"))

def prepare(self, results, reference_text, **kwargs):
def prepare(self, logprobs: list[float] | float, reference_text: str, **kwargs):
"""Prepares an individual perplexity example to the format expected by metrics computed at the corpus level (aggregated).
Args:
results (list[float]): List of the logprobabilities computed for each item
logprobs (list[float]): List of the logprobabilities computed for each item of the sequence or single aggregated logprob over the sequence
reference_text (str): Current reference text for which to compute the length in self.units_type
Returns:
PerplexityCorpusMetricInput: Stores the measured logprobs and associated text lengths, counted in the reference unit.
"""
return PerplexityCorpusMetricInput(logprobs=results.result, weights=self.count_units(reference_text))
return PerplexityCorpusMetricInput(logprobs=logprobs, weights=self.count_units(reference_text))
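prepare() now takes the log-probabilities directly rather than the whole result object, and pairs them with the reference length in the configured unit. A rough sketch of that pairing, with a simplified stand-in for PerplexityCorpusMetricInput:

from dataclasses import dataclass

@dataclass
class FakePerplexityInput:  # simplified stand-in for PerplexityCorpusMetricInput
    logprobs: float
    weights: int

def count_units(text: str, units_type: str = "bytes") -> int:
    if units_type == "words":
        return len(text.split())
    return len(text.encode("utf-8"))  # byte count, as in the diff above

def prepare(logprobs: float, reference_text: str) -> FakePerplexityInput:
    # Corpus-level perplexity later aggregates logprobs weighted by these lengths.
    return FakePerplexityInput(logprobs=logprobs, weights=count_units(reference_text))

print(prepare(-42.7, "Some reference text."))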
(Diffs for the remaining 13 changed files are not shown.)
